Make it possible to join pseudo ID rooms (#3119)

This should allow us to join pseudo ID rooms over federation.
Also needs https://github.com/matrix-org/gomatrixserverlib/pull/397
This commit is contained in:
Till 2023-06-28 19:22:39 +02:00 committed by GitHub
parent 6697f4377d
commit 86f217f692
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 146 additions and 71 deletions

View file

@ -2,6 +2,7 @@ package internal
import ( import (
"context" "context"
"crypto/ed25519"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -170,13 +171,24 @@ func (r *FederationInternalAPI) performJoinUsingServer(
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, },
SenderIDCreator: func(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (spec.SenderID, error) { GetOrCreateSenderID: func(ctx context.Context, userID spec.UserID, roomID spec.RoomID, roomVersion string) (spec.SenderID, ed25519.PrivateKey, error) {
// assign a roomNID, otherwise we can't create a private key for the user
_, nidErr := r.rsAPI.AssignRoomNID(ctx, roomID, gomatrixserverlib.RoomVersion(roomVersion))
if nidErr != nil {
return "", nil, nidErr
}
key, keyErr := r.rsAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID) key, keyErr := r.rsAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID)
if keyErr != nil { if keyErr != nil {
return "", keyErr return "", nil, keyErr
} }
return spec.SenderIDFromPseudoIDKey(key), key, nil
return spec.SenderIDFromPseudoIDKey(key), nil },
StoreSenderIDFromPublicID: func(ctx context.Context, senderID spec.SenderID, userIDRaw string, roomID spec.RoomID) error {
storeUserID, userErr := spec.NewUserID(userIDRaw, true)
if userErr != nil {
return userErr
}
return r.rsAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID)
}, },
} }
response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput) response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput)

View file

@ -15,6 +15,7 @@
package routing package routing
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"sort" "sort"
@ -107,6 +108,10 @@ func MakeJoin(
} }
} }
if senderID == "" {
senderID = spec.SenderID(userID.String())
}
input := gomatrixserverlib.HandleMakeJoinInput{ input := gomatrixserverlib.HandleMakeJoinInput{
Context: httpReq.Context(), Context: httpReq.Context(),
UserID: userID, UserID: userID,
@ -218,6 +223,13 @@ func SendJoin(
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
}, },
StoreSenderIDFromPublicID: func(ctx context.Context, senderID spec.SenderID, userIDRaw string, roomID spec.RoomID) error {
userID, userErr := spec.NewUserID(userIDRaw, true)
if userErr != nil {
return userErr
}
return rsAPI.StoreUserRoomPublicKey(ctx, senderID, *userID, roomID)
},
} }
response, joinErr := gomatrixserverlib.HandleSendJoin(input) response, joinErr := gomatrixserverlib.HandleSendJoin(input)
switch e := joinErr.(type) { switch e := joinErr.(type) {

4
go.mod
View file

@ -22,7 +22,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20230616092350-99b78e30a272 github.com/matrix-org/gomatrixserverlib v0.0.0-20230628151943-f6e3c7f7b093
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
github.com/mattn/go-sqlite3 v1.14.16 github.com/mattn/go-sqlite3 v1.14.16
@ -43,6 +43,7 @@ require (
github.com/yggdrasil-network/yggdrasil-go v0.4.6 github.com/yggdrasil-network/yggdrasil-go v0.4.6
go.uber.org/atomic v1.10.0 go.uber.org/atomic v1.10.0
golang.org/x/crypto v0.10.0 golang.org/x/crypto v0.10.0
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db
golang.org/x/image v0.5.0 golang.org/x/image v0.5.0
golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e
golang.org/x/sync v0.1.0 golang.org/x/sync v0.1.0
@ -124,7 +125,6 @@ require (
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect
go.etcd.io/bbolt v1.3.6 // indirect go.etcd.io/bbolt v1.3.6 // indirect
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
golang.org/x/mod v0.8.0 // indirect golang.org/x/mod v0.8.0 // indirect
golang.org/x/net v0.10.0 // indirect golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.9.0 // indirect golang.org/x/sys v0.9.0 // indirect

4
go.sum
View file

@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230616092350-99b78e30a272 h1:YDpKHyojqwQU/L6I0bjp/BYPfn1aq/D2iBxC0jCSc/0= github.com/matrix-org/gomatrixserverlib v0.0.0-20230628151943-f6e3c7f7b093 h1:FHd3SYhU2ZxZhkssZ/7ms5+M2j+g94lYp8ztvA1E6tA=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230616092350-99b78e30a272/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/gomatrixserverlib v0.0.0-20230628151943-f6e3c7f7b093/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=

View file

@ -74,6 +74,7 @@ type RoomserverInternalAPI interface {
type UserRoomPrivateKeyCreator interface { type UserRoomPrivateKeyCreator interface {
// GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created. // GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created.
GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error) GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error)
StoreUserRoomPublicKey(ctx context.Context, senderID spec.SenderID, userID spec.UserID, roomID spec.RoomID) error
} }
type InputRoomEventsAPI interface { type InputRoomEventsAPI interface {
@ -235,7 +236,7 @@ type FederationRoomserverAPI interface {
QueryBulkStateContentAPI QueryBulkStateContentAPI
QuerySenderIDAPI QuerySenderIDAPI
UserRoomPrivateKeyCreator UserRoomPrivateKeyCreator
AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error)
SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error)
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error

View file

@ -289,6 +289,15 @@ func (r *RoomserverInternalAPI) GetOrCreateUserRoomPrivateKey(ctx context.Contex
return key, nil return key, nil
} }
func (r *RoomserverInternalAPI) StoreUserRoomPublicKey(ctx context.Context, senderID spec.SenderID, userID spec.UserID, roomID spec.RoomID) error {
pubKeyBytes, err := senderID.RawBytes()
if err != nil {
return err
}
_, err = r.DB.InsertUserRoomPublicKey(ctx, userID, roomID, ed25519.PublicKey(pubKeyBytes))
return err
}
func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) { func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) {
roomVersion, ok := r.Cache.GetRoomVersion(roomID.String()) roomVersion, ok := r.Cache.GetRoomVersion(roomID.String())
if !ok { if !ok {
@ -307,7 +316,7 @@ func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID s
} }
return fclient.SigningIdentity{ return fclient.SigningIdentity{
PrivateKey: privKey, PrivateKey: privKey,
KeyID: "ed25519", KeyID: "ed25519:1",
ServerName: "self", ServerName: "self",
}, nil }, nil
} }
@ -317,3 +326,7 @@ func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID s
} }
return *identity, err return *identity, err
} }
func (r *RoomserverInternalAPI) AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) {
return r.DB.AssignRoomNID(ctx, roomID, roomVersion)
}

View file

@ -406,7 +406,7 @@ func (r *Inputer) processRoomEvent(
) )
if !isRejected && !isCreateEvent { if !isRejected && !isCreateEvent {
resolver := state.NewStateResolution(r.DB, roomInfo, r.Queryer) resolver := state.NewStateResolution(r.DB, roomInfo, r.Queryer)
redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, &resolver) redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, &resolver, r.Queryer)
if err != nil { if err != nil {
return err return err
} }

View file

@ -647,7 +647,7 @@ func persistEvents(ctx context.Context, db storage.Database, querier api.QuerySe
resolver := state.NewStateResolution(db, roomInfo, querier) resolver := state.NewStateResolution(db, roomInfo, querier)
_, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev, &resolver) _, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev, &resolver, querier)
if err != nil { if err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event") logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event")
continue continue

View file

@ -196,7 +196,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
// sign all events with the pseudo ID key // sign all events with the pseudo ID key
identity = &fclient.SigningIdentity{ identity = &fclient.SigningIdentity{
ServerName: "self", ServerName: "self",
KeyID: "ed25519", KeyID: "ed25519:1",
PrivateKey: pseudoIDKey, PrivateKey: pseudoIDKey,
} }
} }

View file

@ -314,7 +314,7 @@ func (r *Joiner) performJoinRoomByID(
// sign the event with the pseudo ID key // sign the event with the pseudo ID key
identity = fclient.SigningIdentity{ identity = fclient.SigningIdentity{
ServerName: "self", ServerName: "self",
KeyID: "ed25519", KeyID: "ed25519:1",
PrivateKey: pseudoIDKey, PrivateKey: pseudoIDKey,
} }
} }

View file

@ -35,6 +35,14 @@ import (
"github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/test/testrig"
) )
type FakeQuerier struct {
api.QuerySenderIDAPI
}
func (f *FakeQuerier) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}
func TestUsers(t *testing.T) { func TestUsers(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType) cfg, processCtx, close := testrig.CreateConfig(t, dbType)
@ -566,7 +574,7 @@ func TestRedaction(t *testing.T) {
err = updater.Commit() err = updater.Commit()
assert.NoError(t, err) assert.NoError(t, err)
_, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev.PDU, &plResolver) _, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev.PDU, &plResolver, &FakeQuerier{})
assert.NoError(t, err) assert.NoError(t, err)
if redactedEvent != nil { if redactedEvent != nil {
assert.Equal(t, ev.Redacts(), redactedEvent.EventID()) assert.Equal(t, ev.Redacts(), redactedEvent.EventID())

View file

@ -18,6 +18,7 @@ import (
"context" "context"
"crypto/ed25519" "crypto/ed25519"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
@ -190,7 +191,7 @@ type Database interface {
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
MaybeRedactEvent( MaybeRedactEvent(
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, querier api.QuerySenderIDAPI,
) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error) ) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error)
} }
@ -251,7 +252,7 @@ type EventDatabase interface {
// MaybeRedactEvent returns the redaction event and the redacted event if this call resulted in a redaction, else an error // MaybeRedactEvent returns the redaction event and the redacted event if this call resulted in a redaction, else an error
// (nil if there was nothing to do) // (nil if there was nothing to do)
MaybeRedactEvent( MaybeRedactEvent(
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, querier api.QuerySenderIDAPI,
) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error) ) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error)
StoreEvent(ctx context.Context, event gomatrixserverlib.PDU, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error) StoreEvent(ctx context.Context, event gomatrixserverlib.PDU, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error)
} }

View file

@ -10,6 +10,7 @@ import (
"sort" "sort"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -991,6 +992,7 @@ func extractRoomVersionFromCreateEvent(event gomatrixserverlib.PDU) (
// Returns the redaction event and the redacted event if this call resulted in a redaction. // Returns the redaction event and the redacted event if this call resulted in a redaction.
func (d *EventDatabase) MaybeRedactEvent( func (d *EventDatabase) MaybeRedactEvent(
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver,
querier api.QuerySenderIDAPI,
) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error) { ) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error) {
var ( var (
redactionEvent, redactedEvent *types.Event redactionEvent, redactedEvent *types.Event
@ -1030,15 +1032,18 @@ func (d *EventDatabase) MaybeRedactEvent(
return nil return nil
} }
// TODO: Don't hack senderID into userID here (pseudoIDs) var validRoomID *spec.RoomID
validRoomID, err = spec.NewRoomID(redactedEvent.RoomID())
if err != nil {
return err
}
sender1Domain := "" sender1Domain := ""
sender1, err1 := spec.NewUserID(string(redactedEvent.SenderID()), true) sender1, err1 := querier.QueryUserIDForSender(ctx, *validRoomID, redactedEvent.SenderID())
if err1 == nil { if err1 == nil {
sender1Domain = string(sender1.Domain()) sender1Domain = string(sender1.Domain())
} }
// TODO: Don't hack senderID into userID here (pseudoIDs)
sender2Domain := "" sender2Domain := ""
sender2, err2 := spec.NewUserID(string(redactionEvent.SenderID()), true) sender2, err2 := querier.QueryUserIDForSender(ctx, *validRoomID, redactionEvent.SenderID())
if err2 == nil { if err2 == nil {
sender2Domain = string(sender2.Domain()) sender2Domain = string(sender2.Domain())
} }
@ -1757,7 +1762,6 @@ func (d *Database) SelectUserRoomPublicKey(ctx context.Context, userID spec.User
// SelectUserIDsForPublicKeys returns a map from roomID -> map from senderKey -> userID // SelectUserIDsForPublicKeys returns a map from roomID -> map from senderKey -> userID
func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (result map[spec.RoomID]map[string]string, err error) { func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (result map[spec.RoomID]map[string]string, err error) {
result = make(map[spec.RoomID]map[string]string, len(publicKeys)) result = make(map[spec.RoomID]map[string]string, len(publicKeys))
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// map all roomIDs to roomNIDs // map all roomIDs to roomNIDs
query := make(map[types.RoomNID][]ed25519.PublicKey) query := make(map[types.RoomNID][]ed25519.PublicKey)
@ -1765,9 +1769,9 @@ func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys ma
for roomID, keys := range publicKeys { for roomID, keys := range publicKeys {
roomNID, ok := d.Cache.GetRoomServerRoomNID(roomID.String()) roomNID, ok := d.Cache.GetRoomServerRoomNID(roomID.String())
if !ok { if !ok {
roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String()) roomInfo, rErr := d.roomInfo(ctx, nil, roomID.String())
if rErr != nil { if rErr != nil {
return rErr return nil, rErr
} }
if roomInfo == nil { if roomInfo == nil {
logrus.Warnf("missing room info for %s, there will be missing users in the response", roomID.String()) logrus.Warnf("missing room info for %s, there will be missing users in the response", roomID.String())
@ -1781,9 +1785,9 @@ func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys ma
} }
// get the user room key pars // get the user room key pars
userRoomKeyPairMap, sErr := d.UserRoomKeyTable.BulkSelectUserNIDs(ctx, txn, query) userRoomKeyPairMap, sErr := d.UserRoomKeyTable.BulkSelectUserNIDs(ctx, nil, query)
if sErr != nil { if sErr != nil {
return sErr return nil, sErr
} }
nids := make([]types.EventStateKeyNID, 0, len(userRoomKeyPairMap)) nids := make([]types.EventStateKeyNID, 0, len(userRoomKeyPairMap))
for _, nid := range userRoomKeyPairMap { for _, nid := range userRoomKeyPairMap {
@ -1792,7 +1796,7 @@ func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys ma
// get the userIDs // get the userIDs
nidMap, seErr := d.EventStateKeys(ctx, nids) nidMap, seErr := d.EventStateKeys(ctx, nids)
if seErr != nil { if seErr != nil {
return seErr return nil, seErr
} }
// build the result map (roomID -> map publicKey -> userID) // build the result map (roomID -> map publicKey -> userID)
@ -1806,9 +1810,6 @@ func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys ma
resMap[publicKey] = userID resMap[publicKey] = userID
result[roomID] = resMap result[roomID] = resMap
} }
return nil
})
return result, err return result, err
} }

View file

@ -57,6 +57,7 @@ const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_use
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)` const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)`
type userRoomKeysStatements struct { type userRoomKeysStatements struct {
db *sql.DB
insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPrivateKeyStmt *sql.Stmt
insertUserRoomPublicKeyStmt *sql.Stmt insertUserRoomPublicKeyStmt *sql.Stmt
selectUserRoomKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt
@ -70,7 +71,7 @@ func CreateUserRoomKeysTable(db *sql.DB) error {
} }
func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
s := &userRoomKeysStatements{} s := &userRoomKeysStatements{db: db}
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertUserRoomPrivateKeyStmt, insertUserRoomKeySQL}, {&s.insertUserRoomPrivateKeyStmt, insertUserRoomKeySQL},
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
@ -137,7 +138,7 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq
selectSQL := strings.Replace(selectUserNIDsSQL, "($2)", sqlutil.QueryVariadicOffset(len(senders), len(senderKeys)), 1) selectSQL := strings.Replace(selectUserNIDsSQL, "($2)", sqlutil.QueryVariadicOffset(len(senders), len(senderKeys)), 1)
selectSQL = strings.Replace(selectSQL, "($1)", sqlutil.QueryVariadic(len(senderKeys)), 1) // replace $1 with the roomNIDs selectSQL = strings.Replace(selectSQL, "($1)", sqlutil.QueryVariadic(len(senderKeys)), 1) // replace $1 with the roomNIDs
selectStmt, err := txn.Prepare(selectSQL) selectStmt, err := s.db.Prepare(selectSQL)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -350,6 +350,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
// left. Anything that appears in the filtered timeline will be removed from the // left. Anything that appears in the filtered timeline will be removed from the
// "state" section and kept in "timeline". // "state" section and kept in "timeline".
// update the powerlevel event for timeline events
for i, ev := range events { for i, ev := range events {
if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs { if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs {
continue continue
@ -371,7 +372,17 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
) )
delta.StateEvents = make([]*rstypes.HeaderedEvent, len(sEvents)) delta.StateEvents = make([]*rstypes.HeaderedEvent, len(sEvents))
for i := range sEvents { for i := range sEvents {
delta.StateEvents[i] = sEvents[i].(*rstypes.HeaderedEvent) ev := sEvents[i]
delta.StateEvents[i] = ev.(*rstypes.HeaderedEvent)
// update the powerlevel event for state events
if ev.Version() == gomatrixserverlib.RoomVersionPseudoIDs && ev.Type() == spec.MRoomPowerLevels && ev.StateKeyEquals("") {
var newEvent gomatrixserverlib.PDU
newEvent, err = p.updatePowerLevelEvent(ctx, ev.(*rstypes.HeaderedEvent))
if err != nil {
return r.From, err
}
delta.StateEvents[i] = &rstypes.HeaderedEvent{PDU: newEvent}
}
} }
if len(delta.StateEvents) > 0 { if len(delta.StateEvents) > 0 {
@ -652,6 +663,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
prevBatch.Decrement() prevBatch.Decrement()
} }
// Update powerlevel events for timeline events
for i, ev := range events { for i, ev := range events {
if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs { if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs {
continue continue
@ -665,6 +677,20 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
} }
events[i] = &rstypes.HeaderedEvent{PDU: newEvent} events[i] = &rstypes.HeaderedEvent{PDU: newEvent}
} }
// Update powerlevel events for state events
for i, ev := range stateEvents {
if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs {
continue
}
if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") {
continue
}
newEvent, err := p.updatePowerLevelEvent(ctx, ev)
if err != nil {
return nil, err
}
stateEvents[i] = &rstypes.HeaderedEvent{PDU: newEvent}
}
jr.Timeline.PrevBatch = prevBatch jr.Timeline.PrevBatch = prevBatch
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {