Compare commits

...

16 commits

Author SHA1 Message Date
Devon Hudson 2acc285f38
Fix linter errors 2024-01-03 17:39:08 -07:00
Devon Hudson 59753aba98
Refactor createRoom to remove duplicate code 2024-01-03 17:26:19 -07:00
Devon Hudson cbd547c828
Tweaks to federated invites 2023-12-19 16:33:38 -07:00
Devon Hudson 1dc00f5fd2
Update cryptoids to match join & create changes 2023-11-27 13:10:07 -07:00
Devon Hudson b45e72830e
Change cryptoid references from pseudoids 2023-11-17 17:34:01 -07:00
Devon Hudson 3cbccb9ed7
Update msc4080 endpoint names 2023-11-16 14:34:51 -07:00
Devon Hudson 5930a04044
Update send_pdus endpoint to reflect msc changes 2023-11-15 14:30:54 -07:00
Devon Hudson 227493cc5d
Handle cryptoID invites & leaves 2023-11-06 16:54:29 -07:00
Devon Hudson 7f7ac0f4fe
Use client pseudoid during create room 2023-11-02 10:56:25 -06:00
Devon Hudson b7d320f8d1
cryptoID changes 2023-10-31 16:52:58 -06:00
Devon Hudson 60be1391bf
Fix one time pseudoids in upload keys & sync endpoints 2023-10-25 16:23:46 -06:00
Devon Hudson 038103ac7f
Add one time pseudoids to upload keys & sync endpoints 2023-10-18 22:01:16 -06:00
Devon Hudson 29cd14baf5
Add /send & /state support for cryptoids 2023-10-13 17:01:18 -06:00
Devon Hudson 4bfcf27106
Fix local joins & make /rooms/roomid/join endpoint work 2023-10-13 10:12:28 -06:00
Devon Hudson f17de49c6b
CryptoID joins 2023-10-12 17:41:45 -06:00
Devon Hudson a5ba533cfb
Handle createRoom & sendPDUs 2023-10-04 11:29:09 -06:00
45 changed files with 3066 additions and 186 deletions

View file

@ -48,6 +48,7 @@ type createRoomRequest struct {
RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"`
PowerLevelContentOverride json.RawMessage `json:"power_level_content_override"` PowerLevelContentOverride json.RawMessage `json:"power_level_content_override"`
IsDirect bool `json:"is_direct"` IsDirect bool `json:"is_direct"`
CryptoID string `json:"cryptoid"`
} }
func (r createRoomRequest) Validate() *util.JSONResponse { func (r createRoomRequest) Validate() *util.JSONResponse {
@ -107,12 +108,27 @@ type createRoomResponse struct {
RoomAlias string `json:"room_alias,omitempty"` // in synapse not spec RoomAlias string `json:"room_alias,omitempty"` // in synapse not spec
} }
type createRoomCryptoIDsResponse struct {
RoomID string `json:"room_id"`
Version string `json:"room_version"`
PDUs []json.RawMessage `json:"pdus"`
}
func ToProtoEvents(ctx context.Context, events []gomatrixserverlib.PDU, rsAPI roomserverAPI.ClientRoomserverAPI) []json.RawMessage {
result := make([]json.RawMessage, len(events))
for i, event := range events {
result[i] = json.RawMessage(event.JSON())
}
return result
}
// CreateRoom implements /createRoom // CreateRoom implements /createRoom
func CreateRoom( func CreateRoom(
req *http.Request, device *api.Device, req *http.Request, device *api.Device,
cfg *config.ClientAPI, cfg *config.ClientAPI,
profileAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI, profileAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
asAPI appserviceAPI.AppServiceInternalAPI, asAPI appserviceAPI.AppServiceInternalAPI,
cryptoIDs bool,
) util.JSONResponse { ) util.JSONResponse {
var createRequest createRoomRequest var createRequest createRoomRequest
resErr := httputil.UnmarshalJSONRequest(req, &createRequest) resErr := httputil.UnmarshalJSONRequest(req, &createRequest)
@ -129,10 +145,9 @@ func CreateRoom(
JSON: spec.InvalidParam(err.Error()), JSON: spec.InvalidParam(err.Error()),
} }
} }
return createRoom(req.Context(), createRequest, device, cfg, profileAPI, rsAPI, asAPI, evTime) return createRoom(req.Context(), createRequest, device, cfg, profileAPI, rsAPI, asAPI, evTime, cryptoIDs)
} }
// createRoom implements /createRoom
func createRoom( func createRoom(
ctx context.Context, ctx context.Context,
createRequest createRoomRequest, device *api.Device, createRequest createRoomRequest, device *api.Device,
@ -140,6 +155,7 @@ func createRoom(
profileAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI, profileAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
asAPI appserviceAPI.AppServiceInternalAPI, asAPI appserviceAPI.AppServiceInternalAPI,
evTime time.Time, evTime time.Time,
cryptoIDs bool,
) util.JSONResponse { ) util.JSONResponse {
userID, err := spec.NewUserID(device.UserID, true) userID, err := spec.NewUserID(device.UserID, true)
if err != nil { if err != nil {
@ -225,6 +241,7 @@ func createRoom(
EventTime: evTime, EventTime: evTime,
} }
if !cryptoIDs {
roomAlias, createRes := rsAPI.PerformCreateRoom(ctx, *userID, *roomID, &req) roomAlias, createRes := rsAPI.PerformCreateRoom(ctx, *userID, *roomID, &req)
if createRes != nil { if createRes != nil {
return *createRes return *createRes
@ -239,4 +256,27 @@ func createRoom(
Code: 200, Code: 200,
JSON: response, JSON: response,
} }
} else {
req.SenderID = createRequest.CryptoID
createEvents, err := rsAPI.PerformCreateRoomCryptoIDs(ctx, *userID, *roomID, &req)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("MakeCreateRoomEvents failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{Err: err.Error()},
}
}
response := createRoomCryptoIDsResponse{
RoomID: roomID.String(),
Version: string(roomVersion),
PDUs: ToProtoEvents(ctx, createEvents, rsAPI),
}
return util.JSONResponse{
Code: 200,
JSON: response,
}
}
} }

View file

@ -27,6 +27,7 @@ import (
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
func JoinRoomByIDOrAlias( func JoinRoomByIDOrAlias(
@ -144,3 +145,139 @@ func JoinRoomByIDOrAlias(
return result return result
} }
} }
type joinRoomCryptoIDsResponse struct {
RoomID string `json:"room_id"`
Version string `json:"room_version"`
ViaServer string `json:"via_server"`
PDU json.RawMessage `json:"pdu"`
}
func JoinRoomByIDOrAliasCryptoIDs(
req *http.Request,
device *api.Device,
rsAPI roomserverAPI.ClientRoomserverAPI,
profileAPI api.ClientUserAPI,
roomIDOrAlias string,
) util.JSONResponse {
// Prepare to ask the roomserver to perform the room join.
joinReq := roomserverAPI.PerformJoinRequest{
RoomIDOrAlias: roomIDOrAlias,
UserID: device.UserID,
IsGuest: device.AccountType == api.AccountTypeGuest,
Content: map[string]interface{}{},
}
// Check to see if any ?server_name= query parameters were
// given in the request.
if serverNames, ok := req.URL.Query()["server_name"]; ok {
for _, serverName := range serverNames {
joinReq.ServerNames = append(
joinReq.ServerNames,
spec.ServerName(serverName),
)
}
}
// If content was provided in the request then include that
// in the request. It'll get used as a part of the membership
// event content.
_ = httputil.UnmarshalJSONRequest(req, &joinReq.Content)
if senderid, ok := joinReq.Content["cryptoid"]; ok {
logrus.Errorf("CryptoID: %s", senderid.(string))
joinReq.SenderID = spec.SenderID(senderid.(string))
} else {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.Unknown("Missing cryptoid in request body"),
}
}
delete(joinReq.Content, "cryptoid")
// Work out our localpart for the client profile request.
// Request our profile content to populate the request content with.
profile, err := profileAPI.QueryProfile(req.Context(), device.UserID)
switch err {
case nil:
joinReq.Content["displayname"] = profile.DisplayName
joinReq.Content["avatar_url"] = profile.AvatarURL
case appserviceAPI.ErrProfileNotExists:
util.GetLogger(req.Context()).Error("Unable to query user profile, no profile found.")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown("Unable to query user profile, no profile found."),
}
default:
}
// Ask the roomserver to perform the join.
done := make(chan util.JSONResponse, 1)
go func() {
defer close(done)
joinEvent, roomID, version, serverName, err := rsAPI.PerformJoinCryptoIDs(req.Context(), &joinReq)
var response util.JSONResponse
switch e := err.(type) {
case nil: // success case
response = util.JSONResponse{
Code: http.StatusOK,
JSON: joinRoomCryptoIDsResponse{
RoomID: roomID,
Version: string(version),
ViaServer: string(serverName),
PDU: json.RawMessage(joinEvent.JSON()),
},
}
case roomserverAPI.ErrInvalidID:
response = util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.Unknown(e.Error()),
}
case roomserverAPI.ErrNotAllowed:
jsonErr := spec.Forbidden(e.Error())
if device.AccountType == api.AccountTypeGuest {
jsonErr = spec.GuestAccessForbidden(e.Error())
}
response = util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonErr,
}
case *gomatrix.HTTPError: // this ensures we proxy responses over federation to the client
response = util.JSONResponse{
Code: e.Code,
JSON: json.RawMessage(e.Message),
}
case eventutil.ErrRoomNoExists:
response = util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.NotFound(e.Error()),
}
default:
response = util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
done <- response
}()
// Wait either for the join to finish, or for us to hit a reasonable
// timeout, at which point we'll just return a 200 to placate clients.
timer := time.NewTimer(time.Second * 20)
select {
case <-timer.C:
return util.JSONResponse{
Code: http.StatusRequestTimeout,
JSON: spec.Unknown("Failed creating join event with the remote server."),
}
case result := <-done:
// Stop and drain the timer
if !timer.Stop() {
<-timer.C
}
return result
}
}

View file

@ -67,7 +67,7 @@ func TestJoinRoomByIDOrAlias(t *testing.T) {
Preset: spec.PresetPublicChat, Preset: spec.PresetPublicChat,
RoomAliasName: "alias", RoomAliasName: "alias",
Invite: []string{bob.ID}, Invite: []string{bob.ID},
}, aliceDev, &cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) }, aliceDev, &cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now(), false)
crResp, ok := resp.JSON.(createRoomResponse) crResp, ok := resp.JSON.(createRoomResponse)
if !ok { if !ok {
t.Fatalf("response is not a createRoomResponse: %+v", resp) t.Fatalf("response is not a createRoomResponse: %+v", resp)
@ -81,7 +81,7 @@ func TestJoinRoomByIDOrAlias(t *testing.T) {
Visibility: "public", Visibility: "public",
Preset: spec.PresetPublicChat, Preset: spec.PresetPublicChat,
Invite: []string{charlie.ID}, Invite: []string{charlie.ID},
}, aliceDev, &cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) }, aliceDev, &cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now(), false)
crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse) crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse)
if !ok { if !ok {
t.Fatalf("response is not a createRoomResponse: %+v", resp) t.Fatalf("response is not a createRoomResponse: %+v", resp)

View file

@ -97,6 +97,92 @@ type queryKeysRequest struct {
DeviceKeys map[string][]string `json:"device_keys"` DeviceKeys map[string][]string `json:"device_keys"`
} }
type uploadKeysCryptoIDsRequest struct {
DeviceKeys json.RawMessage `json:"device_keys"`
OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"`
OneTimeCryptoIDs map[string]json.RawMessage `json:"one_time_cryptoids"`
}
func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse {
var r uploadKeysCryptoIDsRequest
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
return *resErr
}
uploadReq := &api.PerformUploadKeysRequest{
DeviceID: device.ID,
UserID: device.UserID,
}
if r.DeviceKeys != nil {
uploadReq.DeviceKeys = []api.DeviceKeys{
{
DeviceID: device.ID,
UserID: device.UserID,
KeyJSON: r.DeviceKeys,
},
}
}
if r.OneTimeKeys != nil {
uploadReq.OneTimeKeys = []api.OneTimeKeys{
{
DeviceID: device.ID,
UserID: device.UserID,
KeyJSON: r.OneTimeKeys,
},
}
}
if r.OneTimeCryptoIDs != nil {
uploadReq.OneTimeCryptoIDs = []api.OneTimeCryptoIDs{
{
UserID: device.UserID,
KeyJSON: r.OneTimeCryptoIDs,
},
}
}
util.GetLogger(req.Context()).
WithField("device keys", r.DeviceKeys).
WithField("one-time keys", r.OneTimeKeys).
WithField("one-time cryptoids", r.OneTimeCryptoIDs).
Info("Uploading keys")
var uploadRes api.PerformUploadKeysResponse
if err := keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes); err != nil {
return util.ErrorResponse(err)
}
if uploadRes.Error != nil {
util.GetLogger(req.Context()).WithError(uploadRes.Error).Error("Failed to PerformUploadKeys")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
if len(uploadRes.KeyErrors) > 0 {
util.GetLogger(req.Context()).WithField("key_errors", uploadRes.KeyErrors).Error("Failed to upload one or more keys")
return util.JSONResponse{
Code: 400,
JSON: uploadRes.KeyErrors,
}
}
keyCount := make(map[string]int)
if len(uploadRes.OneTimeKeyCounts) > 0 {
keyCount = uploadRes.OneTimeKeyCounts[0].KeyCount
}
cryptoIDCount := make(map[string]int)
if len(uploadRes.OneTimeCryptoIDCounts) > 0 {
cryptoIDCount = uploadRes.OneTimeCryptoIDCounts[0].KeyCount
}
return util.JSONResponse{
Code: 200,
JSON: struct {
OTKCounts interface{} `json:"one_time_key_counts"`
OTIDCounts interface{} `json:"one_time_cryptoid_counts"`
}{keyCount, cryptoIDCount},
}
}
func (r *queryKeysRequest) GetTimeout() time.Duration { func (r *queryKeysRequest) GetTimeout() time.Duration {
if r.Timeout == 0 { if r.Timeout == 0 {
return 10 * time.Second return 10 * time.Second

View file

@ -15,6 +15,7 @@
package routing package routing
import ( import (
"encoding/json"
"net/http" "net/http"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
@ -23,11 +24,16 @@ import (
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
type leaveRoomCryptoIDsResponse struct {
PDU json.RawMessage `json:"pdu"`
}
func LeaveRoomByID( func LeaveRoomByID(
req *http.Request, req *http.Request,
device *api.Device, device *api.Device,
rsAPI roomserverAPI.ClientRoomserverAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
roomID string, roomID string,
cryptoIDs bool,
) util.JSONResponse { ) util.JSONResponse {
userID, err := spec.NewUserID(device.UserID, true) userID, err := spec.NewUserID(device.UserID, true)
if err != nil { if err != nil {
@ -45,7 +51,8 @@ func LeaveRoomByID(
leaveRes := roomserverAPI.PerformLeaveResponse{} leaveRes := roomserverAPI.PerformLeaveResponse{}
// Ask the roomserver to perform the leave. // Ask the roomserver to perform the leave.
if err := rsAPI.PerformLeave(req.Context(), &leaveReq, &leaveRes); err != nil { leaveEvent, err := rsAPI.PerformLeave(req.Context(), &leaveReq, &leaveRes, cryptoIDs)
if err != nil {
if leaveRes.Code != 0 { if leaveRes.Code != 0 {
return util.JSONResponse{ return util.JSONResponse{
Code: leaveRes.Code, Code: leaveRes.Code,
@ -58,8 +65,15 @@ func LeaveRoomByID(
} }
} }
if cryptoIDs {
return util.JSONResponse{
Code: http.StatusOK,
JSON: leaveRoomCryptoIDsResponse{
PDU: json.RawMessage(leaveEvent.JSON()),
},
}
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: struct{}{},
} }
} }

View file

@ -17,6 +17,7 @@ package routing
import ( import (
"context" "context"
"crypto/ed25519" "crypto/ed25519"
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"time" "time"
@ -39,10 +40,15 @@ import (
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
type membershipCryptoIDsResponse struct {
PDU json.RawMessage `json:"pdu"`
}
func SendBan( func SendBan(
req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device, req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device,
roomID string, cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI,
cryptoIDs bool,
) util.JSONResponse { ) util.JSONResponse {
body, evTime, reqErr := extractRequestData(req) body, evTime, reqErr := extractRequestData(req)
if reqErr != nil { if reqErr != nil {
@ -95,13 +101,14 @@ func SendBan(
} }
} }
return sendMembership(req.Context(), profileAPI, device, roomID, spec.Ban, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI) return sendMembership(req.Context(), profileAPI, device, roomID, spec.Ban, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI, cryptoIDs)
} }
func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, device *userapi.Device, func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, device *userapi.Device,
roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time, roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time,
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI) util.JSONResponse { rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, cryptoIDs bool) util.JSONResponse {
// TODO: cryptoIDs - what about when we don't know the senderID for a user?
event, err := buildMembershipEvent( event, err := buildMembershipEvent(
ctx, targetUserID, reason, profileAPI, device, membership, ctx, targetUserID, reason, profileAPI, device, membership,
roomID, false, cfg, evTime, rsAPI, asAPI, roomID, false, cfg, evTime, rsAPI, asAPI,
@ -114,6 +121,8 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic
} }
} }
var json interface{}
if !cryptoIDs {
serverName := device.UserDomain() serverName := device.UserDomain()
if err = roomserverAPI.SendEvents( if err = roomserverAPI.SendEvents(
ctx, rsAPI, ctx, rsAPI,
@ -131,10 +140,14 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
} }
} }
json = struct{}{}
} else {
json = membershipCryptoIDsResponse{PDU: event.JSON()}
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: struct{}{}, JSON: json,
} }
} }
@ -142,6 +155,7 @@ func SendKick(
req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device, req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device,
roomID string, cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI,
cryptoIDs bool,
) util.JSONResponse { ) util.JSONResponse {
body, evTime, reqErr := extractRequestData(req) body, evTime, reqErr := extractRequestData(req)
if reqErr != nil { if reqErr != nil {
@ -216,13 +230,14 @@ func SendKick(
} }
} }
// TODO: should we be using SendLeave instead? // TODO: should we be using SendLeave instead?
return sendMembership(req.Context(), profileAPI, device, roomID, spec.Leave, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI) return sendMembership(req.Context(), profileAPI, device, roomID, spec.Leave, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI, cryptoIDs)
} }
func SendUnban( func SendUnban(
req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device, req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device,
roomID string, cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI,
cryptoIDs bool,
) util.JSONResponse { ) util.JSONResponse {
body, evTime, reqErr := extractRequestData(req) body, evTime, reqErr := extractRequestData(req)
if reqErr != nil { if reqErr != nil {
@ -272,13 +287,14 @@ func SendUnban(
} }
} }
// TODO: should we be using SendLeave instead? // TODO: should we be using SendLeave instead?
return sendMembership(req.Context(), profileAPI, device, roomID, spec.Leave, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI) return sendMembership(req.Context(), profileAPI, device, roomID, spec.Leave, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI, cryptoIDs)
} }
func SendInvite( func SendInvite(
req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device, req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device,
roomID string, cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI,
cryptoIDs bool,
) util.JSONResponse { ) util.JSONResponse {
body, evTime, reqErr := extractRequestData(req) body, evTime, reqErr := extractRequestData(req)
if reqErr != nil { if reqErr != nil {
@ -323,7 +339,7 @@ func SendInvite(
} }
// We already received the return value, so no need to check for an error here. // We already received the return value, so no need to check for an error here.
response, _ := sendInvite(req.Context(), profileAPI, device, roomID, body.UserID, body.Reason, cfg, rsAPI, asAPI, evTime) response, _ := sendInvite(req.Context(), profileAPI, device, roomID, body.UserID, body.Reason, cfg, rsAPI, asAPI, evTime, cryptoIDs)
return response return response
} }
@ -336,6 +352,7 @@ func sendInvite(
cfg *config.ClientAPI, cfg *config.ClientAPI,
rsAPI roomserverAPI.ClientRoomserverAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
asAPI appserviceAPI.AppServiceInternalAPI, evTime time.Time, asAPI appserviceAPI.AppServiceInternalAPI, evTime time.Time,
cryptoIDs bool,
) (util.JSONResponse, error) { ) (util.JSONResponse, error) {
validRoomID, err := spec.NewRoomID(roomID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil { if err != nil {
@ -372,7 +389,7 @@ func sendInvite(
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
}, err }, err
} }
err = rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{ inviteEvent, err := rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{
InviteInput: roomserverAPI.InviteInput{ InviteInput: roomserverAPI.InviteInput{
RoomID: *validRoomID, RoomID: *validRoomID,
Inviter: *inviter, Inviter: *inviter,
@ -387,7 +404,7 @@ func sendInvite(
}, },
InviteRoomState: nil, // ask the roomserver to draw up invite room state for us InviteRoomState: nil, // ask the roomserver to draw up invite room state for us
SendAsServer: string(device.UserDomain()), SendAsServer: string(device.UserDomain()),
}) }, cryptoIDs)
switch e := err.(type) { switch e := err.(type) {
case roomserverAPI.ErrInvalidID: case roomserverAPI.ErrInvalidID:
@ -410,10 +427,22 @@ func sendInvite(
}, err }, err
} }
return util.JSONResponse{ response := util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: struct{}{}, JSON: struct{}{},
}, nil }
type inviteCryptoIDResponse struct {
PDU json.RawMessage `json:"pdu"`
}
if inviteEvent != nil {
response.JSON = inviteCryptoIDResponse{
PDU: json.RawMessage(inviteEvent.JSON()),
}
}
return response, nil
} }
func buildMembershipEventDirect( func buildMembershipEventDirect(

View file

@ -309,9 +309,33 @@ func Setup(
v3mux.Handle("/createRoom", v3mux.Handle("/createRoom",
httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return CreateRoom(req, device, cfg, userAPI, rsAPI, asAPI) return CreateRoom(req, device, cfg, userAPI, rsAPI, asAPI, false)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/org.matrix.msc4080/createRoom",
httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
logrus.Info("Processing request to /org.matrix.msc4080/createRoom")
return CreateRoom(req, device, cfg, userAPI, rsAPI, asAPI, true)
}),
).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/org.matrix.msc4080/send_pdus/{txnID}",
httputil.MakeAuthAPI("send_pdus", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
logrus.Info("Processing request to /org.matrix.msc4080/send_pdus")
if r := rateLimits.Limit(req, device); r != nil {
return *r
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
txnID := vars["txnID"]
// NOTE: when making events such as for create_room, multiple PDUs will need to be passed between the client & server.
return SendPDUs(req, device, cfg, userAPI, rsAPI, asAPI, &txnID)
}),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/join/{roomIDOrAlias}", v3mux.Handle("/join/{roomIDOrAlias}",
httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req, device); r != nil { if r := rateLimits.Limit(req, device); r != nil {
@ -334,8 +358,32 @@ func Setup(
return resp.(util.JSONResponse) return resp.(util.JSONResponse)
}, httputil.WithAllowGuests()), }, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/org.matrix.msc4080/join/{roomIDOrAlias}",
httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
logrus.Info("Processing request to /org.matrix.msc4080/join/{roomIDOrAlias}")
if r := rateLimits.Limit(req, device); r != nil {
return *r
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
// Only execute a join for roomIDOrAlias and UserID once. If there is a join in progress
// it waits for it to complete and returns that result for subsequent requests.
resp, _, _ := sf.Do(vars["roomIDOrAlias"]+device.UserID, func() (any, error) {
return JoinRoomByIDOrAliasCryptoIDs(
req, device, rsAPI, userAPI, vars["roomIDOrAlias"],
), nil
})
// once all joins are processed, drop them from the cache. Further requests
// will be processed as usual.
sf.Forget(vars["roomIDOrAlias"] + device.UserID)
return resp.(util.JSONResponse)
}, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions)
if mscCfg.Enabled("msc2753") { if mscCfg.Enabled("msc2753") {
// TODO: update for cryptoIDs
v3mux.Handle("/peek/{roomIDOrAlias}", v3mux.Handle("/peek/{roomIDOrAlias}",
httputil.MakeAuthAPI(spec.Peek, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI(spec.Peek, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req, device); r != nil { if r := rateLimits.Limit(req, device); r != nil {
@ -378,6 +426,29 @@ func Setup(
return resp.(util.JSONResponse) return resp.(util.JSONResponse)
}, httputil.WithAllowGuests()), }, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/join",
httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/join")
if r := rateLimits.Limit(req, device); r != nil {
return *r
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
// Only execute a join for roomID and UserID once. If there is a join in progress
// it waits for it to complete and returns that result for subsequent requests.
resp, _, _ := sf.Do(vars["roomID"]+device.UserID, func() (any, error) {
return JoinRoomByIDOrAliasCryptoIDs(
req, device, rsAPI, userAPI, vars["roomID"],
), nil
})
// once all joins are processed, drop them from the cache. Further requests
// will be processed as usual.
sf.Forget(vars["roomID"] + device.UserID)
return resp.(util.JSONResponse)
}, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/leave", v3mux.Handle("/rooms/{roomID}/leave",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req, device); r != nil { if r := rateLimits.Limit(req, device); r != nil {
@ -388,10 +459,26 @@ func Setup(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return LeaveRoomByID( return LeaveRoomByID(
req, device, rsAPI, vars["roomID"], req, device, rsAPI, vars["roomID"], false,
) )
}, httputil.WithAllowGuests()), }, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/leave",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/leave")
if r := rateLimits.Limit(req, device); r != nil {
return *r
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return LeaveRoomByID(
req, device, rsAPI, vars["roomID"], true,
)
}, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions)
// TODO: update for cryptoIDs
v3mux.Handle("/rooms/{roomID}/unpeek", v3mux.Handle("/rooms/{roomID}/unpeek",
httputil.MakeAuthAPI("unpeek", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("unpeek", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -409,7 +496,17 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI) return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, false)
}),
).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/ban",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/ban")
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, true)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/invite", v3mux.Handle("/rooms/{roomID}/invite",
@ -421,7 +518,20 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI) return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, false)
}),
).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/invite",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/invite")
if r := rateLimits.Limit(req, device); r != nil {
return *r
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, true)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/kick", v3mux.Handle("/rooms/{roomID}/kick",
@ -430,7 +540,17 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SendKick(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI) return SendKick(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, false)
}),
).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/kick",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/kick")
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return SendKick(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, true)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/unban", v3mux.Handle("/rooms/{roomID}/unban",
@ -439,7 +559,17 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SendUnban(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI) return SendUnban(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, false)
}),
).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/unban",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/unban")
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return SendUnban(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, true)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/send/{eventType}", v3mux.Handle("/rooms/{roomID}/send/{eventType}",
@ -451,6 +581,16 @@ func Setup(
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil) return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil)
}, httputil.WithAllowGuests()), }, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/send/{eventType}",
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/send/{eventType}")
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return SendEventCryptoIDs(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil)
}, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}", v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}",
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -462,6 +602,18 @@ func Setup(
nil, cfg, rsAPI, transactionsCache) nil, cfg, rsAPI, transactionsCache)
}, httputil.WithAllowGuests()), }, httputil.WithAllowGuests()),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/send/{eventType}/{txnID}",
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/send/{eventType}/{txnID}")
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
txnID := vars["txnID"]
return SendEventCryptoIDs(req, device, vars["roomID"], vars["eventType"], &txnID,
nil, cfg, rsAPI, transactionsCache)
}, httputil.WithAllowGuests()),
).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -510,6 +662,18 @@ func Setup(
return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil) return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil)
}, httputil.WithAllowGuests()), }, httputil.WithAllowGuests()),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/state/{eventType:[^/]+/?}",
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/state/{eventType:[^/]+/?}")
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
emptyString := ""
eventType := strings.TrimSuffix(vars["eventType"], "/")
return SendEventCryptoIDs(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil)
}, httputil.WithAllowGuests()),
).Methods(http.MethodPut, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}", v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
@ -521,6 +685,17 @@ func Setup(
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil) return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil)
}, httputil.WithAllowGuests()), }, httputil.WithAllowGuests()),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/state/{eventType}/{stateKey}",
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/state/{eventType}/{stateKey}")
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
stateKey := vars["stateKey"]
return SendEventCryptoIDs(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil)
}, httputil.WithAllowGuests()),
).Methods(http.MethodPut, http.MethodOptions)
// Defined outside of handler to persist between calls // Defined outside of handler to persist between calls
// TODO: clear based on some criteria // TODO: clear based on some criteria
@ -559,6 +734,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
// TODO: update for cryptoIDs
v3mux.Handle("/directory/room/{roomAlias}", v3mux.Handle("/directory/room/{roomAlias}",
httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -569,6 +745,7 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
// TODO: update for cryptoIDs
v3mux.Handle("/directory/room/{roomAlias}", v3mux.Handle("/directory/room/{roomAlias}",
httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -636,6 +813,7 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
// TODO: update for cryptoIDs
v3mux.Handle("/rooms/{roomID}/typing/{userID}", v3mux.Handle("/rooms/{roomID}/typing/{userID}",
httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req, device); r != nil { if r := rateLimits.Limit(req, device); r != nil {
@ -648,6 +826,7 @@ func Setup(
return SendTyping(req, device, vars["roomID"], vars["userID"], rsAPI, syncProducer) return SendTyping(req, device, vars["roomID"], vars["userID"], rsAPI, syncProducer)
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
// TODO: update for cryptoIDs
v3mux.Handle("/rooms/{roomID}/redact/{eventID}", v3mux.Handle("/rooms/{roomID}/redact/{eventID}",
httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -657,6 +836,7 @@ func Setup(
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, nil, nil) return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, nil, nil)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
// TODO: update for cryptoIDs
v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}", v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}",
httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -668,6 +848,7 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
// TODO: update for cryptoIDs
v3mux.Handle("/sendToDevice/{eventType}/{txnID}", v3mux.Handle("/sendToDevice/{eventType}/{txnID}",
httputil.MakeAuthAPI("send_to_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("send_to_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -1118,6 +1299,7 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
// TODO: update for cryptoIDs
v3mux.Handle("/rooms/{roomID}/read_markers", v3mux.Handle("/rooms/{roomID}/read_markers",
httputil.MakeAuthAPI("rooms_read_markers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_read_markers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req, device); r != nil { if r := rateLimits.Limit(req, device); r != nil {
@ -1144,6 +1326,7 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
// TODO: update for cryptoIDs
v3mux.Handle("/rooms/{roomID}/upgrade", v3mux.Handle("/rooms/{roomID}/upgrade",
httputil.MakeAuthAPI("rooms_upgrade", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_upgrade", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
@ -1472,6 +1655,12 @@ func Setup(
return UploadKeys(req, userAPI, device) return UploadKeys(req, userAPI, device)
}, httputil.WithAllowGuests()), }, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/org.matrix.msc4080/keys/upload",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
logrus.Info("Processing request to /org.matrix.msc4080/keys/upload")
return UploadKeysCryptoIDs(req, userAPI, device)
}, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/query", v3mux.Handle("/keys/query",
httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return QueryKeys(req, userAPI, device) return QueryKeys(req, userAPI, device)
@ -1495,6 +1684,7 @@ func Setup(
return SetReceipt(req, userAPI, syncProducer, device, vars["roomId"], vars["receiptType"], vars["eventId"]) return SetReceipt(req, userAPI, syncProducer, device, vars["roomId"], vars["receiptType"], vars["eventId"])
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
// TODO: update for cryptoIDs
v3mux.Handle("/presence/{userId}/status", v3mux.Handle("/presence/{userId}/status",
httputil.MakeAuthAPI("set_presence", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("set_presence", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))

View file

@ -0,0 +1,272 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"encoding/json"
"net/http"
"sync"
"time"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/httputil"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
)
type PDUInfo struct {
Version string `json:"room_version"`
ViaServer string `json:"via_server,omitempty"`
PDU json.RawMessage `json:"pdu"`
}
type sendPDUsRequest struct {
PDUs []PDUInfo `json:"pdus"`
}
// SendPDUs implements /sendPDUs
// nolint:gocyclo
func SendPDUs(
req *http.Request, device *api.Device,
cfg *config.ClientAPI,
profileAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
asAPI appserviceAPI.AppServiceInternalAPI,
txnID *string,
) util.JSONResponse {
// TODO: cryptoIDs - should this include an "eventType"?
// if it's a bulk send endpoint, I don't think that makes any sense since there are multiple event types
// In that case, how do I know how to treat the events?
// I could sort them all by roomID?
// Then filter them down based on event type? (how do I collect groups of events such as for room creation?)
// Possibly based on event hash tracking that I know were sent to the client?
// For createRoom, I know what the possible list of events are, so try to find those and collect them to send to room creation.
// Could also sort by depth... but that seems dangerous and depth may not be a field forever
// Does it matter at all?
// Can't I just forward all the events to the roomserver?
// Do I need to do any specific processing on them?
var pdus sendPDUsRequest
resErr := httputil.UnmarshalJSONRequest(req, &pdus)
if resErr != nil {
return *resErr
}
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam(err.Error()),
}
}
// create a mutex for the specific user in the specific room
// this avoids a situation where events that are received in quick succession are sent to the roomserver in a jumbled order
// TODO: cryptoIDs - where to get roomID from?
mutex, _ := userRoomSendMutexes.LoadOrStore("roomID"+userID.String(), &sync.Mutex{})
mutex.(*sync.Mutex).Lock()
defer mutex.(*sync.Mutex).Unlock()
inputs := make([]roomserverAPI.InputRoomEvent, 0, len(pdus.PDUs))
for _, event := range pdus.PDUs {
// TODO: cryptoIDs - event hash check?
verImpl, err := gomatrixserverlib.GetRoomVersion(gomatrixserverlib.RoomVersion(event.Version))
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{Err: err.Error()},
}
}
//eventJSON, err := json.Marshal(event)
//if err != nil {
// return util.JSONResponse{
// Code: http.StatusInternalServerError,
// JSON: spec.InternalServerError{Err: err.Error()},
// }
//}
// TODO: cryptoIDs - how should we be converting to a PDU here?
// if the hash matches an event we sent to the client, then the JSON should be good.
// But how do we know how to fill out if the event is redacted if we use the trustedJSON function?
// Also - untrusted JSON seems better - except it strips off the unsigned field?
// Also - gmsl events don't store the `hashes` field... problem?
pdu, err := verImpl.NewEventFromUntrustedJSON(event.PDU)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{Err: err.Error()},
}
}
util.GetLogger(req.Context()).Infof("Processing %s event (%s): %s", pdu.Type(), pdu.EventID(), pdu.JSON())
// Check that the event is signed by the server sending the request.
redacted, err := verImpl.RedactEventJSON(pdu.JSON())
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("RedactEventJSON failed")
continue
}
verifier := gomatrixserverlib.JSONVerifierSelf{}
verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{
ServerName: spec.ServerName(pdu.SenderID()),
Message: redacted,
AtTS: pdu.OriginServerTS(),
ValidityCheckingFunc: gomatrixserverlib.StrictValiditySignatureCheck,
}}
verifyResults, err := verifier.VerifyJSONs(req.Context(), verifyRequests)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("keys.VerifyJSONs failed")
continue
}
if verifyResults[0].Error != nil {
util.GetLogger(req.Context()).WithError(verifyResults[0].Error).Error("Signature check failed: ")
continue
}
switch pdu.Type() {
case spec.MRoomCreate:
case spec.MRoomMember:
var membership gomatrixserverlib.MemberContent
err = json.Unmarshal(pdu.Content(), &membership)
switch {
case err != nil:
util.GetLogger(req.Context()).Errorf("m.room.member event (%s) content invalid: %v", pdu.EventID(), pdu.Content())
continue
case membership.Membership == spec.Join:
deviceUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
}
}
if !cfg.Matrix.IsLocalServerName(pdu.RoomID().Domain()) {
queryReq := roomserverAPI.QueryMembershipForUserRequest{
RoomID: pdu.RoomID().String(),
UserID: *deviceUserID,
}
var queryRes roomserverAPI.QueryMembershipForUserResponse
if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
if !queryRes.IsInRoom {
// This is a join event to a remote room
// TODO: cryptoIDs - figure out how to obtain unsigned contents for outstanding federated invites
joinReq := roomserverAPI.PerformJoinRequestCryptoIDs{
RoomID: pdu.RoomID().String(),
UserID: device.UserID,
IsGuest: device.AccountType == api.AccountTypeGuest,
ServerNames: []spec.ServerName{spec.ServerName(event.ViaServer)},
JoinEvent: pdu,
}
err := rsAPI.PerformSendJoinCryptoIDs(req.Context(), &joinReq)
if err != nil {
util.GetLogger(req.Context()).Errorf("Failed processing %s event (%s): %v", pdu.Type(), pdu.EventID(), err)
}
continue // NOTE: don't send this event on to the roomserver
}
}
case membership.Membership == spec.Invite:
stateKey := pdu.StateKey()
if stateKey == nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("invalid state_key for membership event"),
}
}
invitedUserID, err := rsAPI.QueryUserIDForSender(req.Context(), pdu.RoomID(), spec.SenderID(*stateKey))
if err != nil || invitedUserID == nil {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: spec.NotFound("cannot find userID for invite event"),
}
}
if !cfg.Matrix.IsLocalServerName(spec.ServerName(invitedUserID.Domain())) {
inviteReq := roomserverAPI.PerformInviteRequestCryptoIDs{
RoomID: pdu.RoomID().String(),
UserID: *invitedUserID,
InviteEvent: pdu,
}
err := rsAPI.PerformSendInviteCryptoIDs(req.Context(), &inviteReq)
if err != nil {
util.GetLogger(req.Context()).Errorf("Failed processing %s event (%s): %v", pdu.Type(), pdu.EventID(), err)
}
continue // NOTE: don't send this event on to the roomserver
}
}
}
// TODO: cryptoIDs - does it matter which order these are added?
// yes - if the events are for room creation.
// Could make this a client requirement? ie. events are processed based on the order they appear
// We need to check event validity after processing each event.
// ie. what if the client changes power levels that disallow further events they sent?
// We should be doing this already as part of `SendInputRoomEvents`, but how should we pass this
// failure back to the client?
var transactionID *roomserverAPI.TransactionID
if txnID != nil {
transactionID = &roomserverAPI.TransactionID{
SessionID: device.SessionID, TransactionID: *txnID,
}
}
inputs = append(inputs, roomserverAPI.InputRoomEvent{
Kind: roomserverAPI.KindNew,
Event: &types.HeaderedEvent{PDU: pdu},
Origin: userID.Domain(),
// TODO: cryptoIDs - what to do with this field?
// should probably generate this based on the event type being sent?
//SendAsServer: roomserverAPI.DoNotSendToOtherServers,
TransactionID: transactionID,
})
}
startedSubmittingEvents := time.Now()
// send the events to the roomserver
if err := roomserverAPI.SendInputRoomEvents(req.Context(), rsAPI, userID.Domain(), inputs, false); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{Err: err.Error()},
}
}
timeToSubmitEvents := time.Since(startedSubmittingEvents)
sendEventDuration.With(prometheus.Labels{"action": "submit"}).Observe(float64(timeToSubmitEvents.Milliseconds()))
// Add response to transactionsCache
if txnID != nil {
// TODO : cryptoIDs - fix this
//res := util.JSONResponse{
// Code: http.StatusOK,
// JSON: sendEventResponse{e.EventID()},
//}
//txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res)
}
return util.JSONResponse{
Code: http.StatusOK,
}
}

View file

@ -32,6 +32,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/synctypes"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
@ -44,6 +45,11 @@ type sendEventResponse struct {
EventID string `json:"event_id"` EventID string `json:"event_id"`
} }
type sendEventResponseCryptoIDs struct {
EventID string `json:"event_id"`
PDU json.RawMessage `json:"pdu"`
}
var ( var (
userRoomSendMutexes sync.Map // (roomID+userID) -> mutex. mutexes to ensure correct ordering of sendEvents userRoomSendMutexes sync.Map // (roomID+userID) -> mutex. mutexes to ensure correct ordering of sendEvents
) )
@ -94,7 +100,7 @@ func SendEvent(
} }
// Translate user ID state keys to room keys in pseudo ID rooms // Translate user ID state keys to room keys in pseudo ID rooms
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && stateKey != nil { if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && stateKey != nil {
parsedRoomID, innerErr := spec.NewRoomID(roomID) parsedRoomID, innerErr := spec.NewRoomID(roomID)
if innerErr != nil { if innerErr != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -148,7 +154,7 @@ func SendEvent(
} }
// for power level events we need to replace the userID with the pseudoID // for power level events we need to replace the userID with the pseudoID
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && eventType == spec.MRoomPowerLevels { if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && eventType == spec.MRoomPowerLevels {
err = updatePowerLevels(req, r, roomID, rsAPI) err = updatePowerLevels(req, r, roomID, rsAPI)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -166,7 +172,7 @@ func SendEvent(
} }
} }
e, resErr := generateSendEvent(req.Context(), r, device, roomID, eventType, stateKey, rsAPI, evTime) e, resErr := generateSendEvent(req.Context(), r, device, roomID, eventType, stateKey, rsAPI, evTime, false)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
} }
@ -262,6 +268,156 @@ func SendEvent(
return res return res
} }
// SendEventCryptoIDs implements:
//
// /rooms/{roomID}/send/{eventType}
// /rooms/{roomID}/send/{eventType}/{txnID}
// /rooms/{roomID}/state/{eventType}/{stateKey}
//
// nolint: gocyclo
func SendEventCryptoIDs(
req *http.Request,
device *userapi.Device,
roomID, eventType string, txnID, stateKey *string,
cfg *config.ClientAPI,
rsAPI api.ClientRoomserverAPI,
txnCache *transactions.Cache,
) util.JSONResponse {
roomVersion, err := rsAPI.QueryRoomVersionForRoom(req.Context(), roomID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.UnsupportedRoomVersion(err.Error()),
}
}
if txnID != nil {
// Try to fetch response from transactionsCache
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok {
return *res
}
}
// Translate user ID state keys to room keys in pseudo ID rooms
if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && stateKey != nil {
parsedRoomID, innerErr := spec.NewRoomID(roomID)
if innerErr != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("invalid room ID"),
}
}
newStateKey, innerErr := synctypes.FromClientStateKey(*parsedRoomID, *stateKey, func(roomID spec.RoomID, userID spec.UserID) (*spec.SenderID, error) {
return rsAPI.QuerySenderIDForUser(req.Context(), roomID, userID)
})
if innerErr != nil {
// TODO: work out better logic for failure cases (e.g. sender ID not found)
util.GetLogger(req.Context()).WithError(innerErr).Error("synctypes.FromClientStateKey failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
}
stateKey = newStateKey
}
var r map[string]interface{} // must be a JSON object
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
return *resErr
}
if stateKey != nil {
// If the existing/new state content are equal, return the existing event_id, making the request idempotent.
if resp := stateEqual(req.Context(), rsAPI, eventType, *stateKey, roomID, r); resp != nil {
return *resp
}
}
startedGeneratingEvent := time.Now()
// If we're sending a membership update, make sure to strip the authorised
// via key if it is present, otherwise other servers won't be able to auth
// the event if the room is set to the "restricted" join rule.
if eventType == spec.MRoomMember {
delete(r, "join_authorised_via_users_server")
}
// for power level events we need to replace the userID with the pseudoID
if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && eventType == spec.MRoomPowerLevels {
err = updatePowerLevels(req, r, roomID, rsAPI)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{Err: err.Error()},
}
}
}
evTime, err := httputil.ParseTSParam(req)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam(err.Error()),
}
}
e, resErr := generateSendEvent(req.Context(), r, device, roomID, eventType, stateKey, rsAPI, evTime, true)
if resErr != nil {
return *resErr
}
timeToGenerateEvent := time.Since(startedGeneratingEvent)
// validate that the aliases exists
if eventType == spec.MRoomCanonicalAlias && stateKey != nil && *stateKey == "" {
aliasReq := api.AliasEvent{}
if err = json.Unmarshal(e.Content(), &aliasReq); err != nil {
return util.ErrorResponse(fmt.Errorf("unable to parse alias event: %w", err))
}
if !aliasReq.Valid() {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("Request contains invalid aliases."),
}
}
aliasRes := &api.GetAliasesForRoomIDResponse{}
if err = rsAPI.GetAliasesForRoomID(req.Context(), &api.GetAliasesForRoomIDRequest{RoomID: roomID}, aliasRes); err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
var found int
requestAliases := append(aliasReq.AltAliases, aliasReq.Alias)
for _, alias := range aliasRes.Aliases {
for _, altAlias := range requestAliases {
if altAlias == alias {
found++
}
}
}
// check that we found at least the same amount of existing aliases as are in the request
if aliasReq.Alias != "" && found < len(requestAliases) {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadAlias("No matching alias found."),
}
}
}
sendEventDuration.With(prometheus.Labels{"action": "build"}).Observe(float64(timeToGenerateEvent.Milliseconds()))
res := util.JSONResponse{
Code: http.StatusOK,
JSON: sendEventResponseCryptoIDs{
EventID: e.EventID(),
PDU: json.RawMessage(e.JSON()),
},
}
return res
}
func updatePowerLevels(req *http.Request, r map[string]interface{}, roomID string, rsAPI api.ClientRoomserverAPI) error { func updatePowerLevels(req *http.Request, r map[string]interface{}, roomID string, rsAPI api.ClientRoomserverAPI) error {
users, ok := r["users"] users, ok := r["users"]
if !ok { if !ok {
@ -329,6 +485,7 @@ func generateSendEvent(
roomID, eventType string, stateKey *string, roomID, eventType string, stateKey *string,
rsAPI api.ClientRoomserverAPI, rsAPI api.ClientRoomserverAPI,
evTime time.Time, evTime time.Time,
cryptoIDs bool,
) (gomatrixserverlib.PDU, *util.JSONResponse) { ) (gomatrixserverlib.PDU, *util.JSONResponse) {
// parse the incoming http request // parse the incoming http request
fullUserID, err := spec.NewUserID(device.UserID, true) fullUserID, err := spec.NewUserID(device.UserID, true)
@ -376,13 +533,19 @@ func generateSendEvent(
} }
} }
identity, err := rsAPI.SigningIdentityFor(ctx, *validRoomID, *fullUserID) var identity fclient.SigningIdentity
if err != nil { if !cryptoIDs {
id, idErr := rsAPI.SigningIdentityFor(ctx, *validRoomID, *fullUserID)
if idErr != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
} }
} }
identity = id
} else {
identity.ServerName = spec.ServerName(*senderID)
}
var queryRes api.QueryLatestEventsAndStateResponse var queryRes api.QueryLatestEventsAndStateResponse
e, err := eventutil.QueryAndBuildEvent(ctx, &proto, &identity, evTime, rsAPI, &queryRes) e, err := eventutil.QueryAndBuildEvent(ctx, &proto, &identity, evTime, rsAPI, &queryRes)

View file

@ -169,7 +169,7 @@ func SendServerNotice(
PowerLevelContentOverride: pl, PowerLevelContentOverride: pl,
} }
roomRes := createRoom(ctx, crReq, senderDevice, cfgClient, userAPI, rsAPI, asAPI, time.Now()) roomRes := createRoom(ctx, crReq, senderDevice, cfgClient, userAPI, rsAPI, asAPI, time.Now(), false)
switch data := roomRes.JSON.(type) { switch data := roomRes.JSON.(type) {
case createRoomResponse: case createRoomResponse:
@ -215,7 +215,7 @@ func SendServerNotice(
} }
if !membershipRes.IsInRoom { if !membershipRes.IsInRoom {
// re-invite the user // re-invite the user
res, err := sendInvite(ctx, userAPI, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now()) res, err := sendInvite(ctx, userAPI, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now(), false)
if err != nil { if err != nil {
return res return res
} }
@ -228,7 +228,7 @@ func SendServerNotice(
"body": r.Content.Body, "body": r.Content.Body,
"msgtype": r.Content.MsgType, "msgtype": r.Content.MsgType,
} }
e, resErr := generateSendEvent(ctx, request, senderDevice, roomID, "m.room.message", nil, rsAPI, time.Now()) e, resErr := generateSendEvent(ctx, request, senderDevice, roomID, "m.room.message", nil, rsAPI, time.Now(), false)
if resErr != nil { if resErr != nil {
logrus.Errorf("failed to send message: %+v", resErr) logrus.Errorf("failed to send message: %+v", resErr)
return *resErr return *resErr

View file

@ -198,6 +198,7 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
// state to see if there is an event with that type and state key, if there // state to see if there is an event with that type and state key, if there
// is then (by default) we return the content, otherwise a 404. // is then (by default) we return the content, otherwise a 404.
// If eventFormat=true, sends the whole event else just the content. // If eventFormat=true, sends the whole event else just the content.
// nolint:gocyclo
func OnIncomingStateTypeRequest( func OnIncomingStateTypeRequest(
ctx context.Context, device *userapi.Device, rsAPI api.ClientRoomserverAPI, ctx context.Context, device *userapi.Device, rsAPI api.ClientRoomserverAPI,
roomID, evType, stateKey string, eventFormat bool, roomID, evType, stateKey string, eventFormat bool,
@ -214,7 +215,7 @@ func OnIncomingStateTypeRequest(
} }
// Translate user ID state keys to room keys in pseudo ID rooms // Translate user ID state keys to room keys in pseudo ID rooms
if roomVer == gomatrixserverlib.RoomVersionPseudoIDs { if roomVer == gomatrixserverlib.RoomVersionPseudoIDs || roomVer == gomatrixserverlib.RoomVersionCryptoIDs {
parsedRoomID, err := spec.NewRoomID(roomID) parsedRoomID, err := spec.NewRoomID(roomID)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{

View file

@ -58,12 +58,18 @@ type RoomserverFederationAPI interface {
PerformDirectoryLookup(ctx context.Context, request *PerformDirectoryLookupRequest, response *PerformDirectoryLookupResponse) error PerformDirectoryLookup(ctx context.Context, request *PerformDirectoryLookupRequest, response *PerformDirectoryLookupResponse) error
// Handle an instruction to make_join & send_join with a remote server. // Handle an instruction to make_join & send_join with a remote server.
PerformJoin(ctx context.Context, request *PerformJoinRequest, response *PerformJoinResponse) PerformJoin(ctx context.Context, request *PerformJoinRequest, response *PerformJoinResponse)
PerformMakeJoin(ctx context.Context, request *PerformJoinRequest) (gomatrixserverlib.PDU, gomatrixserverlib.RoomVersion, spec.ServerName, error)
PerformSendJoin(ctx context.Context, request *PerformSendJoinRequestCryptoIDs, response *PerformJoinResponse)
// Handle an instruction to make_leave & send_leave with a remote server. // Handle an instruction to make_leave & send_leave with a remote server.
PerformLeave(ctx context.Context, request *PerformLeaveRequest, response *PerformLeaveResponse) error PerformLeave(ctx context.Context, request *PerformLeaveRequest, response *PerformLeaveResponse, cryptoIDs bool) error
// Handle sending an invite to a remote server. // Handle sending an invite to a remote server.
SendInvite(ctx context.Context, event gomatrixserverlib.PDU, strippedState []gomatrixserverlib.InviteStrippedState) (gomatrixserverlib.PDU, error) SendInvite(ctx context.Context, event gomatrixserverlib.PDU, strippedState []gomatrixserverlib.InviteStrippedState) (gomatrixserverlib.PDU, error)
// Handle sending an invite to a remote server. // Handle sending an invite to a remote server.
SendInviteV3(ctx context.Context, event gomatrixserverlib.ProtoEvent, invitee spec.UserID, version gomatrixserverlib.RoomVersion, strippedState []gomatrixserverlib.InviteStrippedState) (gomatrixserverlib.PDU, error) SendInviteV3(ctx context.Context, event gomatrixserverlib.ProtoEvent, invitee spec.UserID, version gomatrixserverlib.RoomVersion, strippedState []gomatrixserverlib.InviteStrippedState) (gomatrixserverlib.PDU, error)
// Handle sending an invite to a remote server.
MakeInviteCryptoIDs(ctx context.Context, event gomatrixserverlib.ProtoEvent, invitee spec.UserID, version gomatrixserverlib.RoomVersion, strippedState []gomatrixserverlib.InviteStrippedState) (gomatrixserverlib.PDU, error)
// Handle sending an invite to a remote server.
SendInviteCryptoIDs(ctx context.Context, event gomatrixserverlib.PDU, invitee spec.UserID, version gomatrixserverlib.RoomVersion) error
// Handle an instruction to peek a room on a remote server. // Handle an instruction to peek a room on a remote server.
PerformOutboundPeek(ctx context.Context, request *PerformOutboundPeekRequest, response *PerformOutboundPeekResponse) error PerformOutboundPeek(ctx context.Context, request *PerformOutboundPeekRequest, response *PerformOutboundPeekResponse) error
// Query the server names of the joined hosts in a room. // Query the server names of the joined hosts in a room.
@ -168,6 +174,15 @@ type PerformJoinRequest struct {
Unsigned map[string]interface{} `json:"unsigned"` Unsigned map[string]interface{} `json:"unsigned"`
} }
type PerformSendJoinRequestCryptoIDs struct {
RoomID string `json:"room_id"`
UserID string `json:"user_id"`
// The sorted list of servers to try. Servers will be tried sequentially, after de-duplication.
ServerNames types.ServerNames `json:"server_names"`
Unsigned map[string]interface{} `json:"unsigned"`
Event gomatrixserverlib.PDU
}
type PerformJoinResponse struct { type PerformJoinResponse struct {
JoinedVia spec.ServerName JoinedVia spec.ServerName
LastError *gomatrix.HTTPError LastError *gomatrix.HTTPError

View file

@ -239,6 +239,319 @@ func (r *FederationInternalAPI) performJoinUsingServer(
return nil return nil
} }
// PerformMakeJoin implements api.FederationInternalAPI
func (r *FederationInternalAPI) PerformMakeJoin(
ctx context.Context,
request *api.PerformJoinRequest,
) (gomatrixserverlib.PDU, gomatrixserverlib.RoomVersion, spec.ServerName, error) {
// Check that a join isn't already in progress for this user/room.
j := federatedJoin{request.UserID, request.RoomID}
if _, found := r.joins.Load(j); found {
return nil, "", "", &gomatrix.HTTPError{
Code: 429,
Message: `{
"errcode": "M_LIMIT_EXCEEDED",
"error": "There is already a federated join to this room in progress. Please wait for it to finish."
}`, // TODO: Why do none of our error types play nicely with each other?
}
}
r.joins.Store(j, nil)
defer r.joins.Delete(j)
// Deduplicate the server names we were provided but keep the ordering
// as this encodes useful information about which servers are most likely
// to respond.
seenSet := make(map[spec.ServerName]bool)
var uniqueList []spec.ServerName
for _, srv := range request.ServerNames {
if seenSet[srv] || r.cfg.Matrix.IsLocalServerName(srv) {
continue
}
seenSet[srv] = true
uniqueList = append(uniqueList, srv)
}
request.ServerNames = uniqueList
// Try each server that we were provided until we land on one that
// successfully completes the make-join send-join dance.
var lastErr error
for _, serverName := range request.ServerNames {
var joinEvent gomatrixserverlib.PDU
var roomVersion gomatrixserverlib.RoomVersion
var err error
if joinEvent, roomVersion, _, err = r.performMakeJoinUsingServer(
ctx,
request.RoomID,
request.UserID,
request.Content,
serverName,
); err != nil {
logrus.WithError(err).WithFields(logrus.Fields{
"server_name": serverName,
"room_id": request.RoomID,
}).Warnf("Failed to join room through server")
lastErr = err
continue
}
// We're all good.
return joinEvent, roomVersion, serverName, err
}
// If we reach here then we didn't complete a join for some reason.
var httpErr gomatrix.HTTPError
var lastError *gomatrix.HTTPError
if ok := errors.As(lastErr, &httpErr); ok {
httpErr.Message = string(httpErr.Contents)
lastError = &httpErr
} else {
lastError = &gomatrix.HTTPError{
Code: 0,
WrappedError: nil,
Message: "Unknown HTTP error",
}
if lastErr != nil {
lastError.Message = lastErr.Error()
}
}
logrus.Errorf(
"failed to join user %q to room %q through %d server(s): last error %s",
request.UserID, request.RoomID, len(request.ServerNames), lastError,
)
return nil, "", "", lastError
}
func (r *FederationInternalAPI) performMakeJoinUsingServer(
ctx context.Context,
roomID, userID string,
content map[string]interface{},
serverName spec.ServerName,
) (gomatrixserverlib.PDU, gomatrixserverlib.RoomVersion, spec.SenderID, error) {
if !r.shouldAttemptDirectFederation(serverName) {
return nil, "", "", fmt.Errorf("relay servers have no meaningful response for join.")
}
user, err := spec.NewUserID(userID, true)
if err != nil {
return nil, "", "", err
}
room, err := spec.NewRoomID(roomID)
if err != nil {
return nil, "", "", err
}
joinInput := gomatrixserverlib.PerformMakeJoinInput{
UserID: user,
RoomID: room,
ServerName: serverName,
Content: content,
PrivateKey: r.cfg.Matrix.PrivateKey,
KeyID: r.cfg.Matrix.KeyID,
KeyRing: r.keyRing,
GetOrCreateSenderID: func(ctx context.Context, userID spec.UserID, roomID spec.RoomID, roomVersion string) (spec.SenderID, ed25519.PrivateKey, error) {
// assign a roomNID, otherwise we can't create a private key for the user
_, nidErr := r.rsAPI.AssignRoomNID(ctx, roomID, gomatrixserverlib.RoomVersion(roomVersion))
if nidErr != nil {
return "", nil, nidErr
}
key, keyErr := r.rsAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID)
if keyErr != nil {
return "", nil, keyErr
}
return spec.SenderIDFromPseudoIDKey(key), key, nil
},
}
joinEvent, version, senderID, joinErr := gomatrixserverlib.PerformMakeJoin(ctx, r, joinInput)
if joinErr != nil {
if !joinErr.Reachable {
r.statistics.ForServer(joinErr.ServerName).Failure()
} else {
r.statistics.ForServer(joinErr.ServerName).Success(statistics.SendDirect)
}
return nil, "", "", joinErr.Err
}
r.statistics.ForServer(serverName).Success(statistics.SendDirect)
if joinEvent == nil {
return nil, "", "", fmt.Errorf("Received nil joinEvent response from gomatrixserverlib.PerformJoin")
}
return joinEvent, version, senderID, nil
}
// PerformSendJoin implements api.FederationInternalAPI
func (r *FederationInternalAPI) PerformSendJoin(
ctx context.Context,
request *api.PerformSendJoinRequestCryptoIDs,
response *api.PerformJoinResponse,
) {
// Check that a join isn't already in progress for this user/room.
j := federatedJoin{request.UserID, request.RoomID}
if _, found := r.joins.Load(j); found {
response.LastError = &gomatrix.HTTPError{
Code: 429,
Message: `{
"errcode": "M_LIMIT_EXCEEDED",
"error": "There is already a federated join to this room in progress. Please wait for it to finish."
}`,
}
return
}
r.joins.Store(j, nil)
defer r.joins.Delete(j)
// Deduplicate the server names we were provided but keep the ordering
// as this encodes useful information about which servers are most likely
// to respond.
seenSet := make(map[spec.ServerName]bool)
var uniqueList []spec.ServerName
for _, srv := range request.ServerNames {
if seenSet[srv] || r.cfg.Matrix.IsLocalServerName(srv) {
continue
}
seenSet[srv] = true
uniqueList = append(uniqueList, srv)
}
request.ServerNames = uniqueList
// Try each server that we were provided until we land on one that
// successfully completes the make-join send-join dance.
var lastErr error
for _, serverName := range request.ServerNames {
if err := r.performSendJoinUsingServer(
ctx,
request.RoomID,
request.UserID,
request.Unsigned,
request.Event,
serverName,
); err != nil {
logrus.WithError(err).WithFields(logrus.Fields{
"server_name": serverName,
"room_id": request.RoomID,
}).Warnf("Failed to join room through server")
lastErr = err
continue
}
// We're all good.
return
}
// If we reach here then we didn't complete a join for some reason.
var httpErr gomatrix.HTTPError
var lastError *gomatrix.HTTPError
if ok := errors.As(lastErr, &httpErr); ok {
httpErr.Message = string(httpErr.Contents)
lastError = &httpErr
} else {
lastError = &gomatrix.HTTPError{
Code: 0,
WrappedError: nil,
Message: "Unknown HTTP error",
}
if lastErr != nil {
lastError.Message = lastErr.Error()
}
}
logrus.Errorf(
"failed to join user %q to room %q through %d server(s): last error %s",
request.UserID, request.RoomID, len(request.ServerNames), lastError,
)
}
func (r *FederationInternalAPI) performSendJoinUsingServer(
ctx context.Context,
roomID, userID string,
unsigned map[string]interface{},
event gomatrixserverlib.PDU,
serverName spec.ServerName,
) error {
user, err := spec.NewUserID(userID, true)
if err != nil {
return err
}
room, err := spec.NewRoomID(roomID)
if err != nil {
return err
}
senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, *room, *user)
if err != nil {
return err
}
joinInput := gomatrixserverlib.PerformSendJoinInput{
RoomID: room,
ServerName: serverName,
Unsigned: unsigned,
Origin: user.Domain(),
SenderID: *senderID,
KeyRing: r.keyRing,
Event: event,
RoomVersion: event.Version(),
EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}),
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
},
StoreSenderIDFromPublicID: func(ctx context.Context, senderID spec.SenderID, userIDRaw string, roomID spec.RoomID) error {
storeUserID, userErr := spec.NewUserID(userIDRaw, true)
if userErr != nil {
return userErr
}
return r.rsAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID)
},
}
response, joinErr := gomatrixserverlib.PerformSendJoin(ctx, r, joinInput)
if joinErr != nil {
if !joinErr.Reachable {
r.statistics.ForServer(joinErr.ServerName).Failure()
} else {
r.statistics.ForServer(joinErr.ServerName).Success(statistics.SendDirect)
}
return joinErr.Err
}
r.statistics.ForServer(serverName).Success(statistics.SendDirect)
if response == nil {
return fmt.Errorf("Received nil response from gomatrixserverlib.PerformSendJoin")
}
// We need to immediately update our list of joined hosts for this room now as we are technically
// joined. We must do this synchronously: we cannot rely on the roomserver output events as they
// will happen asyncly. If we don't update this table, you can end up with bad failure modes like
// joining a room, waiting for 200 OK then changing device keys and have those keys not be sent
// to other servers (this was a cause of a flakey sytest "Local device key changes get to remote servers")
// The events are trusted now as we performed auth checks above.
joinedHosts, err := consumers.JoinedHostsFromEvents(ctx, response.StateSnapshot.GetStateEvents().TrustedEvents(response.JoinEvent.Version(), false), r.rsAPI)
if err != nil {
return fmt.Errorf("JoinedHostsFromEvents: failed to get joined hosts: %s", err)
}
logrus.WithField("room", roomID).Infof("Joined federated room with %d hosts", len(joinedHosts))
if _, err = r.db.UpdateRoom(context.Background(), roomID, joinedHosts, nil, true); err != nil {
return fmt.Errorf("UpdatedRoom: failed to update room with joined hosts: %s", err)
}
// TODO: Can I change this to not take respState but instead just take an opaque list of events?
if err = roomserverAPI.SendEventWithState(
context.Background(),
r.rsAPI,
user.Domain(),
roomserverAPI.KindNew,
response.StateSnapshot,
&types.HeaderedEvent{PDU: response.JoinEvent},
serverName,
nil,
false,
); err != nil {
return fmt.Errorf("roomserverAPI.SendEventWithState: %w", err)
}
return nil
}
// PerformOutboundPeekRequest implements api.FederationInternalAPI // PerformOutboundPeekRequest implements api.FederationInternalAPI
func (r *FederationInternalAPI) PerformOutboundPeek( func (r *FederationInternalAPI) PerformOutboundPeek(
ctx context.Context, ctx context.Context,
@ -433,6 +746,7 @@ func (r *FederationInternalAPI) PerformLeave(
ctx context.Context, ctx context.Context,
request *api.PerformLeaveRequest, request *api.PerformLeaveRequest,
response *api.PerformLeaveResponse, response *api.PerformLeaveResponse,
cryptoIDs bool,
) (err error) { ) (err error) {
userID, err := spec.NewUserID(request.UserID, true) userID, err := spec.NewUserID(request.UserID, true)
if err != nil { if err != nil {
@ -649,6 +963,87 @@ func (r *FederationInternalAPI) SendInviteV3(
return inviteEvent, nil return inviteEvent, nil
} }
// MakeInviteCryptoIDs implements api.FederationInternalAPI
func (r *FederationInternalAPI) MakeInviteCryptoIDs(ctx context.Context, event gomatrixserverlib.ProtoEvent, invitee spec.UserID, version gomatrixserverlib.RoomVersion, strippedState []gomatrixserverlib.InviteStrippedState) (gomatrixserverlib.PDU, error) {
validRoomID, err := spec.NewRoomID(event.RoomID)
if err != nil {
return nil, err
}
verImpl, err := gomatrixserverlib.GetRoomVersion(version)
if err != nil {
return nil, err
}
inviter, err := r.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(event.SenderID))
if err != nil {
return nil, err
}
// TODO (devon): This should be allowed via a relay. Currently only transactions
// can be sent to relays. Would need to extend relays to handle invites.
if !r.shouldAttemptDirectFederation(invitee.Domain()) {
return nil, fmt.Errorf("relay servers have no meaningful response for invite.")
}
logrus.WithFields(logrus.Fields{
"user_id": invitee.String(),
"room_id": event.RoomID,
"room_version": version,
"destination": invitee.Domain(),
}).Info("Sending /send_invite")
inviteReq, err := fclient.NewInviteV3Request(event, version, strippedState)
if err != nil {
return nil, fmt.Errorf("gomatrixserverlib.NewInviteV3Request: %w", err)
}
inviteRes, err := r.federation.MakeInviteCryptoIDs(ctx, inviter.Domain(), invitee.Domain(), inviteReq, invitee)
if err != nil {
return nil, fmt.Errorf("r.federation.SendInviteCryptoIDs: failed to send invite: %w", err)
}
inviteEvent, err := verImpl.NewEventFromUntrustedJSON(inviteRes.Event)
if err != nil {
return nil, fmt.Errorf("r.federation.SendInviteCryptoIDs failed to decode event response: %w", err)
}
return inviteEvent, nil
}
// SendInviteCryptoIDs implements api.FederationInternalAPI
func (r *FederationInternalAPI) SendInviteCryptoIDs(ctx context.Context, event gomatrixserverlib.PDU, invitee spec.UserID, version gomatrixserverlib.RoomVersion) error {
validRoomID := event.RoomID()
inviter, err := r.rsAPI.QueryUserIDForSender(ctx, validRoomID, event.SenderID())
if err != nil {
return err
}
// TODO (devon): This should be allowed via a relay. Currently only transactions
// can be sent to relays. Would need to extend relays to handle invites.
if !r.shouldAttemptDirectFederation(invitee.Domain()) {
return fmt.Errorf("relay servers have no meaningful response for invite.")
}
logrus.WithFields(logrus.Fields{
"user_id": invitee.String(),
"room_id": event.RoomID,
"room_version": version,
"destination": invitee.Domain(),
}).Info("Sending /make_invite")
inviteReq, err := fclient.NewSendInviteCryptoIDsRequest(event, version)
if err != nil {
return fmt.Errorf("gomatrixserverlib.NewInviteV3Request: %w", err)
}
err = r.federation.SendInviteCryptoIDs(ctx, inviter.Domain(), invitee.Domain(), inviteReq, invitee)
if err != nil {
return fmt.Errorf("r.federation.SendInviteCryptoIDs: failed to send invite: %w", err)
}
return nil
}
// PerformServersAlive implements api.FederationInternalAPI // PerformServersAlive implements api.FederationInternalAPI
func (r *FederationInternalAPI) PerformBroadcastEDU( func (r *FederationInternalAPI) PerformBroadcastEDU(
ctx context.Context, ctx context.Context,

14
go.mod
View file

@ -1,5 +1,9 @@
module github.com/matrix-org/dendrite module github.com/matrix-org/dendrite
//replace github.com/matrix-org/gomatrixserverlib => ../../gomatrixserverlib/crypto-ids/
//replace github.com/matrix-org/gomatrixserverlib => /src/gmsl/
require ( require (
github.com/Arceliar/ironwood v0.0.0-20221025225125-45b4281814c2 github.com/Arceliar/ironwood v0.0.0-20221025225125-45b4281814c2
github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979 github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979
@ -22,7 +26,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20230926165653-79fcff283fc4 github.com/matrix-org/gomatrixserverlib v0.0.0-20231219232834-bbfb4a048862
github.com/matrix-org/pinecone v0.11.1-0.20230810010612-ea4c33717fd7 github.com/matrix-org/pinecone v0.11.1-0.20230810010612-ea4c33717fd7
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
github.com/mattn/go-sqlite3 v1.14.17 github.com/mattn/go-sqlite3 v1.14.17
@ -42,12 +46,12 @@ require (
github.com/uber/jaeger-lib v2.4.1+incompatible github.com/uber/jaeger-lib v2.4.1+incompatible
github.com/yggdrasil-network/yggdrasil-go v0.4.6 github.com/yggdrasil-network/yggdrasil-go v0.4.6
go.uber.org/atomic v1.10.0 go.uber.org/atomic v1.10.0
golang.org/x/crypto v0.13.0 golang.org/x/crypto v0.17.0
golang.org/x/exp v0.0.0-20230809150735-7b3493d9a819 golang.org/x/exp v0.0.0-20230809150735-7b3493d9a819
golang.org/x/image v0.5.0 golang.org/x/image v0.5.0
golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e
golang.org/x/sync v0.3.0 golang.org/x/sync v0.3.0
golang.org/x/term v0.12.0 golang.org/x/term v0.15.0
gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/h2non/bimg.v1 v1.1.9
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v2 v2.4.0
gotest.tools/v3 v3.4.0 gotest.tools/v3 v3.4.0
@ -124,8 +128,8 @@ require (
go.etcd.io/bbolt v1.3.6 // indirect go.etcd.io/bbolt v1.3.6 // indirect
golang.org/x/mod v0.12.0 // indirect golang.org/x/mod v0.12.0 // indirect
golang.org/x/net v0.14.0 // indirect golang.org/x/net v0.14.0 // indirect
golang.org/x/sys v0.12.0 // indirect golang.org/x/sys v0.15.0 // indirect
golang.org/x/text v0.13.0 // indirect golang.org/x/text v0.14.0 // indirect
golang.org/x/time v0.3.0 // indirect golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.12.0 // indirect golang.org/x/tools v0.12.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect google.golang.org/protobuf v1.30.0 // indirect

20
go.sum
View file

@ -208,8 +208,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230926165653-79fcff283fc4 h1:UuXfC7b29RBDfMdLmggeF3opu3XuGi8bNT9SKZtZc3I= github.com/matrix-org/gomatrixserverlib v0.0.0-20231219232834-bbfb4a048862 h1:Kuya3qas85ZvVVkuOpemwhgvdJbLojvwvt3xyJTp1dY=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230926165653-79fcff283fc4/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/gomatrixserverlib v0.0.0-20231219232834-bbfb4a048862/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
github.com/matrix-org/pinecone v0.11.1-0.20230810010612-ea4c33717fd7 h1:6t8kJr8i1/1I5nNttw6nn1ryQJgzVlBmSGgPiiaTdw4= github.com/matrix-org/pinecone v0.11.1-0.20230810010612-ea4c33717fd7 h1:6t8kJr8i1/1I5nNttw6nn1ryQJgzVlBmSGgPiiaTdw4=
github.com/matrix-org/pinecone v0.11.1-0.20230810010612-ea4c33717fd7/go.mod h1:ReWMS/LoVnOiRAdq9sNUC2NZnd1mZkMNB52QhpTRWjg= github.com/matrix-org/pinecone v0.11.1-0.20230810010612-ea4c33717fd7/go.mod h1:ReWMS/LoVnOiRAdq9sNUC2NZnd1mZkMNB52QhpTRWjg=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
@ -354,8 +354,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -418,19 +418,19 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.12.0 h1:/ZfYdc3zq+q02Rv9vGqTeSItdzZTSNDmfTi0mBAuidU= golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4=
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=

View file

@ -85,13 +85,21 @@ func BuildEvent(
} }
builder := verImpl.NewEventBuilderFromProtoEvent(proto) builder := verImpl.NewEventBuilderFromProtoEvent(proto)
event, err := builder.Build( var event gomatrixserverlib.PDU
if identity.PrivateKey != nil {
event, err = builder.Build(
evTime, identity.ServerName, identity.KeyID, evTime, identity.ServerName, identity.KeyID,
identity.PrivateKey, identity.PrivateKey,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else {
event, err = builder.BuildWithoutSigning(evTime, identity.ServerName)
if err != nil {
return nil, err
}
}
return &types.HeaderedEvent{PDU: event}, nil return &types.HeaderedEvent{PDU: event}, nil
} }

View file

@ -92,6 +92,7 @@ type UserRoomPrivateKeyCreator interface {
// GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created. // GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created.
GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error) GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error)
StoreUserRoomPublicKey(ctx context.Context, senderID spec.SenderID, userID spec.UserID, roomID spec.RoomID) error StoreUserRoomPublicKey(ctx context.Context, senderID spec.SenderID, userID spec.UserID, roomID spec.RoomID) error
ClaimOneTimeSenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
} }
type InputRoomEventsAPI interface { type InputRoomEventsAPI interface {
@ -222,6 +223,7 @@ type ClientRoomserverAPI interface {
DefaultRoomVersionAPI DefaultRoomVersionAPI
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error)
QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error)
QueryStateAfterEvents(ctx context.Context, req *QueryStateAfterEventsRequest, res *QueryStateAfterEventsResponse) error QueryStateAfterEvents(ctx context.Context, req *QueryStateAfterEventsRequest, res *QueryStateAfterEventsResponse) error
// QueryKnownUsers returns a list of users that we know about from our joined rooms. // QueryKnownUsers returns a list of users that we know about from our joined rooms.
@ -232,6 +234,7 @@ type ClientRoomserverAPI interface {
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error
GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error
PerformCreateRoomCryptoIDs(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) ([]gomatrixserverlib.PDU, error)
PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse) PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse)
// PerformRoomUpgrade upgrades a room to a newer version // PerformRoomUpgrade upgrades a room to a newer version
PerformRoomUpgrade(ctx context.Context, roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error) PerformRoomUpgrade(ctx context.Context, roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error)
@ -241,9 +244,12 @@ type ClientRoomserverAPI interface {
PerformAdminDownloadState(ctx context.Context, roomID, userID string, serverName spec.ServerName) error PerformAdminDownloadState(ctx context.Context, roomID, userID string, serverName spec.ServerName) error
PerformPeek(ctx context.Context, req *PerformPeekRequest) (roomID string, err error) PerformPeek(ctx context.Context, req *PerformPeekRequest) (roomID string, err error)
PerformUnpeek(ctx context.Context, roomID, userID, deviceID string) error PerformUnpeek(ctx context.Context, roomID, userID, deviceID string) error
PerformInvite(ctx context.Context, req *PerformInviteRequest) error PerformInvite(ctx context.Context, req *PerformInviteRequest, cryptoIDs bool) (gomatrixserverlib.PDU, error)
PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error) PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error)
PerformLeave(ctx context.Context, req *PerformLeaveRequest, res *PerformLeaveResponse) error PerformSendJoinCryptoIDs(ctx context.Context, req *PerformJoinRequestCryptoIDs) error
PerformSendInviteCryptoIDs(ctx context.Context, req *PerformInviteRequestCryptoIDs) error
PerformJoinCryptoIDs(ctx context.Context, req *PerformJoinRequest) (join gomatrixserverlib.PDU, roomID string, version gomatrixserverlib.RoomVersion, serverName spec.ServerName, err error)
PerformLeave(ctx context.Context, req *PerformLeaveRequest, res *PerformLeaveResponse, cryptoIDs bool) (gomatrixserverlib.PDU, error)
PerformPublish(ctx context.Context, req *PerformPublishRequest) error PerformPublish(ctx context.Context, req *PerformPublishRequest) error
// PerformForget forgets a rooms history for a specific user // PerformForget forgets a rooms history for a specific user
PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error
@ -305,7 +311,7 @@ type FederationRoomserverAPI interface {
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
HandleInvite(ctx context.Context, event *types.HeaderedEvent) error HandleInvite(ctx context.Context, event *types.HeaderedEvent) error
PerformInvite(ctx context.Context, req *PerformInviteRequest) error PerformInvite(ctx context.Context, req *PerformInviteRequest, cryptoIDs bool) (gomatrixserverlib.PDU, error)
// Query a given amount (or less) of events prior to a given set of events. // Query a given amount (or less) of events prior to a given set of events.
PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error

View file

@ -29,6 +29,8 @@ type PerformCreateRoomRequest struct {
KeyID gomatrixserverlib.KeyID KeyID gomatrixserverlib.KeyID
PrivateKey ed25519.PrivateKey PrivateKey ed25519.PrivateKey
EventTime time.Time EventTime time.Time
SenderID string
} }
type PerformJoinRequest struct { type PerformJoinRequest struct {
@ -38,6 +40,23 @@ type PerformJoinRequest struct {
Content map[string]interface{} `json:"content"` Content map[string]interface{} `json:"content"`
ServerNames []spec.ServerName `json:"server_names"` ServerNames []spec.ServerName `json:"server_names"`
Unsigned map[string]interface{} `json:"unsigned"` Unsigned map[string]interface{} `json:"unsigned"`
SenderID spec.SenderID
}
type PerformJoinRequestCryptoIDs struct {
RoomID string
UserID string
IsGuest bool
Content map[string]interface{}
ServerNames []spec.ServerName
Unsigned map[string]interface{}
JoinEvent gomatrixserverlib.PDU
}
type PerformInviteRequestCryptoIDs struct {
RoomID string
UserID spec.UserID
InviteEvent gomatrixserverlib.PDU
} }
type PerformLeaveRequest struct { type PerformLeaveRequest struct {

View file

@ -3,6 +3,7 @@ package internal
import ( import (
"context" "context"
"crypto/ed25519" "crypto/ed25519"
"fmt"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -54,6 +55,7 @@ type RoomserverInternalAPI struct {
ServerACLs *acls.ServerACLs ServerACLs *acls.ServerACLs
fsAPI fsAPI.RoomserverFederationAPI fsAPI fsAPI.RoomserverFederationAPI
asAPI asAPI.AppServiceInternalAPI asAPI asAPI.AppServiceInternalAPI
usAPI userapi.RoomserverUserAPI
NATSClient *nats.Conn NATSClient *nats.Conn
JetStream nats.JetStreamContext JetStream nats.JetStreamContext
Durable string Durable string
@ -214,6 +216,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) { func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) {
r.Leaver.UserAPI = userAPI r.Leaver.UserAPI = userAPI
r.Inputer.UserAPI = userAPI r.Inputer.UserAPI = userAPI
r.usAPI = userAPI
} }
func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) { func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) {
@ -251,24 +254,27 @@ func (r *RoomserverInternalAPI) PerformCreateRoom(
func (r *RoomserverInternalAPI) PerformInvite( func (r *RoomserverInternalAPI) PerformInvite(
ctx context.Context, ctx context.Context,
req *api.PerformInviteRequest, req *api.PerformInviteRequest,
) error { cryptoIDs bool,
return r.Inviter.PerformInvite(ctx, req) ) (gomatrixserverlib.PDU, error) {
return r.Inviter.PerformInvite(ctx, req, cryptoIDs)
} }
func (r *RoomserverInternalAPI) PerformLeave( func (r *RoomserverInternalAPI) PerformLeave(
ctx context.Context, ctx context.Context,
req *api.PerformLeaveRequest, req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse, res *api.PerformLeaveResponse,
) error { cryptoIDs bool,
outputEvents, err := r.Leaver.PerformLeave(ctx, req, res) ) (gomatrixserverlib.PDU, error) {
outputEvents, leaveEvent, err := r.Leaver.PerformLeave(ctx, req, res, cryptoIDs)
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
return err return nil, err
} }
if len(outputEvents) == 0 { if len(outputEvents) == 0 {
return nil return leaveEvent, nil
} }
return r.OutputProducer.ProduceRoomEvents(req.RoomID, outputEvents) // TODO: cryptoIDs - what to do with this?
return leaveEvent, r.OutputProducer.ProduceRoomEvents(req.RoomID, outputEvents)
} }
func (r *RoomserverInternalAPI) PerformForget( func (r *RoomserverInternalAPI) PerformForget(
@ -308,6 +314,10 @@ func (r *RoomserverInternalAPI) StoreUserRoomPublicKey(ctx context.Context, send
return err return err
} }
func (r *RoomserverInternalAPI) ClaimOneTimeSenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
return r.usAPI.ClaimOneTimeCryptoID(ctx, roomID, userID)
}
func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) { func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) {
roomVersion, ok := r.Cache.GetRoomVersion(roomID.String()) roomVersion, ok := r.Cache.GetRoomVersion(roomID.String())
if !ok { if !ok {
@ -330,6 +340,19 @@ func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID s
ServerName: spec.ServerName(spec.SenderIDFromPseudoIDKey(privKey)), ServerName: spec.ServerName(spec.SenderIDFromPseudoIDKey(privKey)),
}, nil }, nil
} }
if roomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
sender, err := r.QuerySenderIDForUser(ctx, roomID, senderID)
if err != nil {
return fclient.SigningIdentity{}, err
} else if sender == nil {
return fclient.SigningIdentity{}, fmt.Errorf("no sender ID for %s in %s", senderID.String(), roomID.String())
}
return fclient.SigningIdentity{
PrivateKey: nil,
KeyID: "ed25519:1",
ServerName: spec.ServerName(*sender),
}, nil
}
identity, err := r.Cfg.Global.SigningIdentityFor(senderID.Domain()) identity, err := r.Cfg.Global.SigningIdentityFor(senderID.Domain())
if err != nil { if err != nil {
return fclient.SigningIdentity{}, err return fclient.SigningIdentity{}, err

View file

@ -445,7 +445,7 @@ func (r *Inputer) processRoomEvent(
} }
// TODO: Revist this to ensure we don't replace a current state mxid_mapping with an older one. // TODO: Revist this to ensure we don't replace a current state mxid_mapping with an older one.
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && event.Type() == spec.MRoomMember { if (event.Version() == gomatrixserverlib.RoomVersionPseudoIDs || event.Version() == gomatrixserverlib.RoomVersionCryptoIDs) && event.Type() == spec.MRoomMember {
mapping := gomatrixserverlib.MemberContent{} mapping := gomatrixserverlib.MemberContent{}
if err = json.Unmarshal(event.Content(), &mapping); err != nil { if err = json.Unmarshal(event.Content(), &mapping); err != nil {
return err return err

View file

@ -179,7 +179,7 @@ func (r *Admin) PerformAdminEvacuateUser(
Leaver: *fullUserID, Leaver: *fullUserID,
} }
leaveRes := &api.PerformLeaveResponse{} leaveRes := &api.PerformLeaveResponse{}
outputEvents, err := r.Leaver.PerformLeave(ctx, leaveReq, leaveRes) outputEvents, _, err := r.Leaver.PerformLeave(ctx, leaveReq, leaveRes, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -44,6 +44,417 @@ type Creator struct {
RSAPI api.RoomserverInternalAPI RSAPI api.RoomserverInternalAPI
} }
// PerformCreateRoomCryptoIDs handles all the steps necessary to create a new room.
// nolint: gocyclo
func (c *Creator) PerformCreateRoomCryptoIDs(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *api.PerformCreateRoomRequest) ([]gomatrixserverlib.PDU, error) {
verImpl, err := gomatrixserverlib.GetRoomVersion(createRequest.RoomVersion)
if err != nil {
return nil, spec.BadJSON("unknown room version")
}
createContent := map[string]interface{}{}
if len(createRequest.CreationContent) > 0 {
if err = json.Unmarshal(createRequest.CreationContent, &createContent); err != nil {
util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for creation_content failed")
return nil, spec.BadJSON("invalid create content")
}
}
senderID := spec.SenderID(createRequest.SenderID)
// TODO: cryptoIDs - should we be assigning a room NID yet?
_, err = c.DB.AssignRoomNID(ctx, roomID, createRequest.RoomVersion)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to assign roomNID")
return nil, spec.InternalServerError{Err: err.Error()}
}
if createRequest.RoomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
util.GetLogger(ctx).Infof("StoreUserRoomPublicKey - SenderID: %s UserID: %s RoomID: %s", senderID, userID.String(), roomID.String())
bytes := spec.Base64Bytes{}
err = bytes.Decode(string(senderID))
if err != nil {
return nil, err
}
if len(bytes) != ed25519.PublicKeySize {
return nil, spec.BadJSON("SenderID is not a valid ed25519 public key")
}
keyErr := c.RSAPI.StoreUserRoomPublicKey(ctx, senderID, userID, roomID)
if keyErr != nil {
util.GetLogger(ctx).WithError(keyErr).Error("StoreUserRoomPublicKey failed")
return nil, spec.InternalServerError{Err: keyErr.Error()}
}
}
createContent["creator"] = senderID
createContent["room_version"] = createRequest.RoomVersion
powerLevelContent := eventutil.InitialPowerLevelsContent(string(senderID))
joinRuleContent := gomatrixserverlib.JoinRuleContent{
JoinRule: spec.Invite,
}
historyVisibilityContent := gomatrixserverlib.HistoryVisibilityContent{
HistoryVisibility: historyVisibilityShared,
}
if createRequest.PowerLevelContentOverride != nil {
// Merge powerLevelContentOverride fields by unmarshalling it atop the defaults
err = json.Unmarshal(createRequest.PowerLevelContentOverride, &powerLevelContent)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for power_level_content_override failed")
return nil, spec.BadJSON("malformed power_level_content_override")
}
}
var guestsCanJoin bool
switch createRequest.StatePreset {
case spec.PresetPrivateChat:
joinRuleContent.JoinRule = spec.Invite
historyVisibilityContent.HistoryVisibility = historyVisibilityShared
guestsCanJoin = true
case spec.PresetTrustedPrivateChat:
joinRuleContent.JoinRule = spec.Invite
historyVisibilityContent.HistoryVisibility = historyVisibilityShared
for _, invitee := range createRequest.InvitedUsers {
powerLevelContent.Users[invitee] = 100
}
guestsCanJoin = true
case spec.PresetPublicChat:
joinRuleContent.JoinRule = spec.Public
historyVisibilityContent.HistoryVisibility = historyVisibilityShared
}
createEvent := gomatrixserverlib.FledglingEvent{
Type: spec.MRoomCreate,
Content: createContent,
}
powerLevelEvent := gomatrixserverlib.FledglingEvent{
Type: spec.MRoomPowerLevels,
Content: powerLevelContent,
}
joinRuleEvent := gomatrixserverlib.FledglingEvent{
Type: spec.MRoomJoinRules,
Content: joinRuleContent,
}
historyVisibilityEvent := gomatrixserverlib.FledglingEvent{
Type: spec.MRoomHistoryVisibility,
Content: historyVisibilityContent,
}
membershipEvent := gomatrixserverlib.FledglingEvent{
Type: spec.MRoomMember,
StateKey: string(senderID),
}
memberContent := gomatrixserverlib.MemberContent{
Membership: spec.Join,
DisplayName: createRequest.UserDisplayName,
AvatarURL: createRequest.UserAvatarURL,
}
// If we are creating a room with pseudo IDs, create and sign the MXIDMapping
if createRequest.RoomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
mapping := &gomatrixserverlib.MXIDMapping{
UserRoomKey: senderID,
UserID: userID.String(),
}
identity, idErr := c.Cfg.Matrix.SigningIdentityFor(userID.Domain()) // we MUST use the server signing mxid_mapping
if idErr != nil {
logrus.WithError(idErr).WithField("domain", userID.Domain()).Error("unable to find signing identity for domain")
return nil, spec.InternalServerError{Err: idErr.Error()}
}
// Sign the mapping with the server identity
if err = mapping.Sign(identity.ServerName, identity.KeyID, identity.PrivateKey); err != nil {
return nil, spec.InternalServerError{Err: err.Error()}
}
memberContent.MXIDMapping = mapping
}
membershipEvent.Content = memberContent
var nameEvent *gomatrixserverlib.FledglingEvent
var topicEvent *gomatrixserverlib.FledglingEvent
var guestAccessEvent *gomatrixserverlib.FledglingEvent
var aliasEvent *gomatrixserverlib.FledglingEvent
if createRequest.RoomName != "" {
nameEvent = &gomatrixserverlib.FledglingEvent{
Type: spec.MRoomName,
Content: eventutil.NameContent{
Name: createRequest.RoomName,
},
}
}
if createRequest.Topic != "" {
topicEvent = &gomatrixserverlib.FledglingEvent{
Type: spec.MRoomTopic,
Content: eventutil.TopicContent{
Topic: createRequest.Topic,
},
}
}
if guestsCanJoin {
guestAccessEvent = &gomatrixserverlib.FledglingEvent{
Type: spec.MRoomGuestAccess,
Content: eventutil.GuestAccessContent{
GuestAccess: "can_join",
},
}
}
var roomAlias string
if createRequest.RoomAliasName != "" {
roomAlias = fmt.Sprintf("#%s:%s", createRequest.RoomAliasName, userID.Domain())
// check it's free
// TODO: This races but is better than nothing
hasAliasReq := api.GetRoomIDForAliasRequest{
Alias: roomAlias,
IncludeAppservices: false,
}
var aliasResp api.GetRoomIDForAliasResponse
err = c.RSAPI.GetRoomIDForAlias(ctx, &hasAliasReq, &aliasResp)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed")
return nil, spec.InternalServerError{Err: err.Error()}
}
if aliasResp.RoomID != "" {
return nil, spec.RoomInUse("Room ID already exists.")
}
aliasEvent = &gomatrixserverlib.FledglingEvent{
Type: spec.MRoomCanonicalAlias,
Content: eventutil.CanonicalAlias{
Alias: roomAlias,
},
}
}
var initialStateEvents []gomatrixserverlib.FledglingEvent
for i := range createRequest.InitialState {
if createRequest.InitialState[i].StateKey != "" {
initialStateEvents = append(initialStateEvents, createRequest.InitialState[i])
continue
}
switch createRequest.InitialState[i].Type {
case spec.MRoomCreate:
continue
case spec.MRoomPowerLevels:
powerLevelEvent = createRequest.InitialState[i]
case spec.MRoomJoinRules:
joinRuleEvent = createRequest.InitialState[i]
case spec.MRoomHistoryVisibility:
historyVisibilityEvent = createRequest.InitialState[i]
case spec.MRoomGuestAccess:
guestAccessEvent = &createRequest.InitialState[i]
case spec.MRoomName:
nameEvent = &createRequest.InitialState[i]
case spec.MRoomTopic:
topicEvent = &createRequest.InitialState[i]
default:
initialStateEvents = append(initialStateEvents, createRequest.InitialState[i])
}
}
// send events into the room in order of:
// 1- m.room.create
// 2- room creator join member
// 3- m.room.power_levels
// 4- m.room.join_rules
// 5- m.room.history_visibility
// 6- m.room.canonical_alias (opt)
// 7- m.room.guest_access (opt)
// 8- other initial state items
// 9- m.room.name (opt)
// 10- m.room.topic (opt)
// 11- invite events (opt) - with is_direct flag if applicable TODO
// 12- 3pid invite events (opt) TODO
// This differs from Synapse slightly. Synapse would vary the ordering of 3-7
// depending on if those events were in "initial_state" or not. This made it
// harder to reason about, hence sticking to a strict static ordering.
eventsToMake := []gomatrixserverlib.FledglingEvent{
createEvent, membershipEvent, powerLevelEvent, joinRuleEvent, historyVisibilityEvent,
}
if guestAccessEvent != nil {
eventsToMake = append(eventsToMake, *guestAccessEvent)
}
eventsToMake = append(eventsToMake, initialStateEvents...)
if nameEvent != nil {
eventsToMake = append(eventsToMake, *nameEvent)
}
if topicEvent != nil {
eventsToMake = append(eventsToMake, *topicEvent)
}
if aliasEvent != nil {
// TODO: bit of a chicken and egg problem here as the alias doesn't exist and cannot until we have made the room.
// This means we might fail creating the alias but say the canonical alias is something that doesn't exist.
eventsToMake = append(eventsToMake, *aliasEvent)
}
var builtEvents []gomatrixserverlib.PDU
authEvents := gomatrixserverlib.NewAuthEvents(nil)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("rsapi.QuerySenderIDForUser failed")
return nil, spec.InternalServerError{Err: err.Error()}
}
for i, e := range eventsToMake {
depth := i + 1 // depth starts at 1
builder := verImpl.NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{
SenderID: string(senderID),
RoomID: roomID.String(),
Type: e.Type,
StateKey: &e.StateKey,
Depth: int64(depth),
})
err = builder.SetContent(e.Content)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("builder.SetContent failed")
return nil, spec.InternalServerError{Err: err.Error()}
}
if i > 0 {
builder.PrevEvents = []string{builtEvents[i-1].EventID()}
}
var ev gomatrixserverlib.PDU
if err = builder.AddAuthEvents(&authEvents); err != nil {
util.GetLogger(ctx).WithError(err).Error("AddAuthEvents failed")
return nil, spec.InternalServerError{Err: err.Error()}
}
ev, err = builder.BuildWithoutSigning(createRequest.EventTime, userID.Domain())
if err != nil {
util.GetLogger(ctx).WithError(err).Error("buildEvent failed")
return nil, spec.InternalServerError{Err: err.Error()}
}
if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return c.RSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed")
return nil, spec.InternalServerError{Err: err.Error()}
}
// Add the event to the list of auth events
builtEvents = append(builtEvents, ev)
err = authEvents.AddEvent(ev)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("authEvents.AddEvent failed")
return nil, spec.InternalServerError{Err: err.Error()}
}
}
// TODO(#269): Reserve room alias while we create the room. This stops us
// from creating the room but still failing due to the alias having already
// been taken.
if roomAlias != "" {
aliasAlreadyExists, aliasErr := c.RSAPI.SetRoomAlias(ctx, senderID, roomID, roomAlias)
if aliasErr != nil {
util.GetLogger(ctx).WithError(aliasErr).Error("aliasAPI.SetRoomAlias failed")
return nil, spec.InternalServerError{Err: aliasErr.Error()}
}
if aliasAlreadyExists {
return nil, spec.RoomInUse("Room alias already exists.")
}
}
// TODO: cryptoIDs - this shouldn't really be done until the client calls /sendPDUs with these events
// But that would require the visibility setting also being passed along
if createRequest.Visibility == spec.Public {
// expose this room in the published room list
if err = c.RSAPI.PerformPublish(ctx, &api.PerformPublishRequest{
RoomID: roomID.String(),
Visibility: spec.Public,
}); err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to publish room")
return nil, spec.InternalServerError{Err: err.Error()}
}
}
// If this is a direct message then we should invite the participants.
//if len(createRequest.InvitedUsers) > 0 {
// Build some stripped state for the invite.
//var globalStrippedState []gomatrixserverlib.InviteStrippedState
//for _, event := range builtEvents {
// // Chosen events from the spec:
// // https://spec.matrix.org/v1.3/client-server-api/#stripped-state
// switch event.Type() {
// case spec.MRoomCreate:
// fallthrough
// case spec.MRoomName:
// fallthrough
// case spec.MRoomAvatar:
// fallthrough
// case spec.MRoomTopic:
// fallthrough
// case spec.MRoomCanonicalAlias:
// fallthrough
// case spec.MRoomEncryption:
// fallthrough
// case spec.MRoomMember:
// fallthrough
// case spec.MRoomJoinRules:
// ev := event
// globalStrippedState = append(
// globalStrippedState,
// gomatrixserverlib.NewInviteStrippedState(ev),
// )
// }
//}
// Process the invites.
//for _, invitee := range createRequest.InvitedUsers {
//inviteeUserID, userIDErr := spec.NewUserID(invitee, true)
//if userIDErr != nil {
// util.GetLogger(ctx).WithError(userIDErr).Error("invalid UserID")
// return nil, spec.InternalServerError{}
//}
// TODO: cryptoIDs - these shouldn't be here
// instead we should return proto invite events?
//err = c.RSAPI.PerformInvite(ctx, &api.PerformInviteRequest{
// InviteInput: api.InviteInput{
// RoomID: roomID,
// Inviter: userID,
// Invitee: *inviteeUserID,
// DisplayName: createRequest.UserDisplayName,
// AvatarURL: createRequest.UserAvatarURL,
// Reason: "",
// IsDirect: createRequest.IsDirect,
// KeyID: createRequest.KeyID,
// PrivateKey: createRequest.PrivateKey,
// EventTime: createRequest.EventTime,
// },
// InviteRoomState: globalStrippedState,
// SendAsServer: string(userID.Domain()),
//})
//switch e := err.(type) {
//case api.ErrInvalidID:
// return nil, spec.Unknown(e.Error())
//case api.ErrNotAllowed:
// return nil, spec.Forbidden(e.Error())
//case nil:
//default:
// util.GetLogger(ctx).WithError(err).Error("PerformInvite failed")
// sentry.CaptureException(err)
// return nil, spec.InternalServerError{}
//}
//}
//}
return builtEvents, nil
}
// PerformCreateRoom handles all the steps necessary to create a new room. // PerformCreateRoom handles all the steps necessary to create a new room.
// nolint: gocyclo // nolint: gocyclo
func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *api.PerformCreateRoomRequest) (string, *util.JSONResponse) { func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *api.PerformCreateRoomRequest) (string, *util.JSONResponse) {
@ -501,7 +912,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
} }
} }
err = c.RSAPI.PerformInvite(ctx, &api.PerformInviteRequest{ _, err = c.RSAPI.PerformInvite(ctx, &api.PerformInviteRequest{
InviteInput: api.InviteInput{ InviteInput: api.InviteInput{
RoomID: roomID, RoomID: roomID,
Inviter: userID, Inviter: userID,
@ -516,7 +927,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
}, },
InviteRoomState: globalStrippedState, InviteRoomState: globalStrippedState,
SendAsServer: string(userID.Domain()), SendAsServer: string(userID.Domain()),
}) }, false)
switch e := err.(type) { switch e := err.(type) {
case api.ErrInvalidID: case api.ErrInvalidID:
return "", &util.JSONResponse{ return "", &util.JSONResponse{

View file

@ -125,16 +125,17 @@ func (r *Inviter) ProcessInviteMembership(
func (r *Inviter) PerformInvite( func (r *Inviter) PerformInvite(
ctx context.Context, ctx context.Context,
req *api.PerformInviteRequest, req *api.PerformInviteRequest,
) error { cryptoIDs bool,
) (gomatrixserverlib.PDU, error) {
senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.InviteInput.RoomID, req.InviteInput.Inviter) senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.InviteInput.RoomID, req.InviteInput.Inviter)
if err != nil { if err != nil {
return err return nil, err
} else if senderID == nil { } else if senderID == nil {
return fmt.Errorf("sender ID not found for %s in %s", req.InviteInput.Inviter, req.InviteInput.RoomID) return nil, fmt.Errorf("sender ID not found for %s in %s", req.InviteInput.Inviter, req.InviteInput.RoomID)
} }
info, err := r.DB.RoomInfo(ctx, req.InviteInput.RoomID.String()) info, err := r.DB.RoomInfo(ctx, req.InviteInput.RoomID.String())
if err != nil { if err != nil {
return err return nil, err
} }
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
@ -152,20 +153,20 @@ func (r *Inviter) PerformInvite(
} }
if err = proto.SetContent(content); err != nil { if err = proto.SetContent(content); err != nil {
return err return nil, err
} }
if !r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Inviter.Domain()) { if !r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Inviter.Domain()) {
return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")} return nil, api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")}
} }
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Invitee.Domain()) isTargetLocal := r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Invitee.Domain())
signingKey := req.InviteInput.PrivateKey signingKey := req.InviteInput.PrivateKey
if info.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { if !cryptoIDs && info.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
signingKey, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, req.InviteInput.Inviter, req.InviteInput.RoomID) signingKey, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, req.InviteInput.Inviter, req.InviteInput.RoomID)
if err != nil { if err != nil {
return err return nil, err
} }
} }
@ -222,6 +223,10 @@ func (r *Inviter) PerformInvite(
} }
return r.RSAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID) return r.RSAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID)
}, },
CryptoIDs: cryptoIDs,
ClaimSenderID: func(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
return r.RSAPI.ClaimOneTimeSenderIDForUser(ctx, roomID, userID)
},
} }
inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI) inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI)
@ -229,12 +234,14 @@ func (r *Inviter) PerformInvite(
switch e := err.(type) { switch e := err.(type) {
case spec.MatrixError: case spec.MatrixError:
if e.ErrCode == spec.ErrorForbidden { if e.ErrCode == spec.ErrorForbidden {
return api.ErrNotAllowed{Err: fmt.Errorf("%s", e.Err)} return nil, api.ErrNotAllowed{Err: fmt.Errorf("%s", e.Err)}
} }
} }
return err return nil, err
} }
var response gomatrixserverlib.PDU
if !cryptoIDs {
// Send the invite event to the roomserver input stream. This will // Send the invite event to the roomserver input stream. This will
// notify existing users in the room about the invite, update the // notify existing users in the room about the invite, update the
// membership table and ensure that the event is ready and available // membership table and ensure that the event is ready and available
@ -254,8 +261,11 @@ func (r *Inviter) PerformInvite(
r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes) r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes)
if err := inputRes.Err(); err != nil { if err := inputRes.Err(); err != nil {
util.GetLogger(ctx).WithField("event_id", inviteEvent.EventID()).Error("r.InputRoomEvents failed") util.GetLogger(ctx).WithField("event_id", inviteEvent.EventID()).Error("r.InputRoomEvents failed")
return api.ErrNotAllowed{Err: err} return nil, api.ErrNotAllowed{Err: err}
}
} else {
response = inviteEvent
} }
return nil return response, nil
} }

View file

@ -406,6 +406,380 @@ func (r *Joiner) performJoinRoomByID(
return req.RoomIDOrAlias, userDomain, nil return req.RoomIDOrAlias, userDomain, nil
} }
func (r *Joiner) PerformSendJoinCryptoIDs(
ctx context.Context,
req *rsAPI.PerformJoinRequestCryptoIDs,
) error {
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
"room_id": req.RoomID,
"user_id": req.UserID,
"servers": req.ServerNames,
})
logger.Info("performing send join")
res := fsAPI.PerformJoinResponse{}
r.FSAPI.PerformSendJoin(ctx, &fsAPI.PerformSendJoinRequestCryptoIDs{
RoomID: req.RoomID,
UserID: req.UserID,
ServerNames: req.ServerNames,
Unsigned: req.Unsigned,
Event: req.JoinEvent,
}, &res)
if res.LastError != nil {
return res.LastError
}
return nil
}
func (r *Joiner) PerformSendInviteCryptoIDs(
ctx context.Context,
req *rsAPI.PerformInviteRequestCryptoIDs,
) error {
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
"room_id": req.RoomID,
"user_id": req.UserID,
})
logger.Info("performing send invite")
err := r.FSAPI.SendInviteCryptoIDs(ctx, req.InviteEvent, req.UserID, req.InviteEvent.Version())
if err != nil {
return err
}
return nil
}
// PerformJoin handles joining matrix rooms, including over federation by talking to the federationapi.
func (r *Joiner) PerformJoinCryptoIDs(
ctx context.Context,
req *rsAPI.PerformJoinRequest,
) (joinEvent gomatrixserverlib.PDU, roomID string, version gomatrixserverlib.RoomVersion, serverName spec.ServerName, err error) {
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
"room_id": req.RoomIDOrAlias,
"user_id": req.UserID,
"servers": req.ServerNames,
})
logger.Info("User requested to room join")
join, roomID, version, serverName, err := r.makeJoinEvent(context.Background(), req)
if err != nil {
logger.WithError(err).Error("Failed to make join room event")
sentry.CaptureException(err)
return nil, "", "", "", err
}
return join, roomID, version, serverName, nil
}
func (r *Joiner) makeJoinEvent(
ctx context.Context,
req *rsAPI.PerformJoinRequest,
) (gomatrixserverlib.PDU, string, gomatrixserverlib.RoomVersion, spec.ServerName, error) {
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
return nil, "", "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("supplied user ID %q in incorrect format", req.UserID)}
}
if !r.Cfg.Matrix.IsLocalServerName(domain) {
return nil, "", "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user %q does not belong to this homeserver", req.UserID)}
}
if strings.HasPrefix(req.RoomIDOrAlias, "!") {
return r.performJoinRoomByIDCryptoIDs(ctx, req)
}
if strings.HasPrefix(req.RoomIDOrAlias, "#") {
return r.performJoinRoomByAliasCryptoIDs(ctx, req)
}
return nil, "", "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID or alias %q is invalid", req.RoomIDOrAlias)}
}
func (r *Joiner) performJoinRoomByAliasCryptoIDs(
ctx context.Context,
req *rsAPI.PerformJoinRequest,
) (gomatrixserverlib.PDU, string, gomatrixserverlib.RoomVersion, spec.ServerName, error) {
// Get the domain part of the room alias.
_, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias)
if err != nil {
return nil, "", "", "", fmt.Errorf("alias %q is not in the correct format", req.RoomIDOrAlias)
}
req.ServerNames = append(req.ServerNames, domain)
// Check if this alias matches our own server configuration. If it
// doesn't then we'll need to try a federated join.
var roomID string
if !r.Cfg.Matrix.IsLocalServerName(domain) {
// The alias isn't owned by us, so we will need to try joining using
// a remote server.
dirReq := fsAPI.PerformDirectoryLookupRequest{
RoomAlias: req.RoomIDOrAlias, // the room alias to lookup
ServerName: domain, // the server to ask
}
dirRes := fsAPI.PerformDirectoryLookupResponse{}
err = r.FSAPI.PerformDirectoryLookup(ctx, &dirReq, &dirRes)
if err != nil {
logrus.WithError(err).Errorf("error looking up alias %q", req.RoomIDOrAlias)
return nil, "", "", "", fmt.Errorf("looking up alias %q over federation failed: %w", req.RoomIDOrAlias, err)
}
roomID = dirRes.RoomID
req.ServerNames = append(req.ServerNames, dirRes.ServerNames...)
} else {
var getRoomReq = rsAPI.GetRoomIDForAliasRequest{
Alias: req.RoomIDOrAlias,
IncludeAppservices: true,
}
var getRoomRes = rsAPI.GetRoomIDForAliasResponse{}
// Otherwise, look up if we know this room alias locally.
err = r.RSAPI.GetRoomIDForAlias(ctx, &getRoomReq, &getRoomRes)
if err != nil {
return nil, "", "", "", fmt.Errorf("lookup room alias %q failed: %w", req.RoomIDOrAlias, err)
}
roomID = getRoomRes.RoomID
}
// If the room ID is empty then we failed to look up the alias.
if roomID == "" {
return nil, "", "", "", fmt.Errorf("alias %q not found", req.RoomIDOrAlias)
}
// If we do, then pluck out the room ID and continue the join.
req.RoomIDOrAlias = roomID
return r.performJoinRoomByIDCryptoIDs(ctx, req)
}
// TODO: Break this function up a bit & move to GMSL
// nolint:gocyclo
func (r *Joiner) performJoinRoomByIDCryptoIDs(
ctx context.Context,
req *rsAPI.PerformJoinRequest,
) (gomatrixserverlib.PDU, string, gomatrixserverlib.RoomVersion, spec.ServerName, error) {
// The original client request ?server_name=... may include this HS so filter that out so we
// don't attempt to make_join with ourselves
for i := 0; i < len(req.ServerNames); i++ {
if r.Cfg.Matrix.IsLocalServerName(req.ServerNames[i]) {
// delete this entry
req.ServerNames = append(req.ServerNames[:i], req.ServerNames[i+1:]...)
i--
}
}
// Get the domain part of the room ID.
roomID, err := spec.NewRoomID(req.RoomIDOrAlias)
if err != nil {
return nil, "", "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", req.RoomIDOrAlias, err)}
}
// If the server name in the room ID isn't ours then it's a
// possible candidate for finding the room via federation. Add
// it to the list of servers to try.
if !r.Cfg.Matrix.IsLocalServerName(roomID.Domain()) {
req.ServerNames = append(req.ServerNames, roomID.Domain())
}
// Force a federated join if we aren't in the room and we've been
// given some server names to try joining by.
inRoomReq := &rsAPI.QueryServerJoinedToRoomRequest{
RoomID: req.RoomIDOrAlias,
}
inRoomRes := &rsAPI.QueryServerJoinedToRoomResponse{}
if err = r.Queryer.QueryServerJoinedToRoom(ctx, inRoomReq, inRoomRes); err != nil {
return nil, "", "", "", fmt.Errorf("r.Queryer.QueryServerJoinedToRoom: %w", err)
}
serverInRoom := inRoomRes.IsInRoom
forceFederatedJoin := len(req.ServerNames) > 0 && !serverInRoom
userID, err := spec.NewUserID(req.UserID, true)
if err != nil {
return nil, "", "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)}
}
//TODO: CryptoIDs - what is provided & calculated senderIDs don't match?
// Look up the room NID for the supplied room ID.
var senderID spec.SenderID
checkInvitePending := false
info, err := r.DB.RoomInfo(ctx, req.RoomIDOrAlias)
if err == nil && info != nil {
switch info.RoomVersion {
case gomatrixserverlib.RoomVersionCryptoIDs:
senderIDPtr, queryErr := r.Queryer.QuerySenderIDForUser(ctx, *roomID, *userID)
if queryErr == nil {
checkInvitePending = true
}
if senderIDPtr == nil {
senderID = req.SenderID
} else {
senderID = *senderIDPtr
}
default:
checkInvitePending = true
senderID = spec.SenderID(userID.String())
}
}
// Force a federated join if we're dealing with a pending invite
// and we aren't in the room.
if checkInvitePending {
isInvitePending, inviteSender, _, inviteEvent, inviteErr := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID)
if inviteErr == nil && !serverInRoom && isInvitePending {
inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, inviteSender)
if queryErr != nil {
return nil, "", "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr)
}
// If we were invited by someone from another server then we can
// assume they are in the room so we can join via them.
if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) {
req.ServerNames = append(req.ServerNames, inviter.Domain())
forceFederatedJoin = true
memberEvent := gjson.Parse(string(inviteEvent.JSON()))
// only set unsigned if we've got a content.membership, which we _should_
if memberEvent.Get("content.membership").Exists() {
req.Unsigned = map[string]interface{}{
"prev_sender": memberEvent.Get("sender").Str,
"prev_content": map[string]interface{}{
"is_direct": memberEvent.Get("content.is_direct").Bool(),
"membership": memberEvent.Get("content.membership").Str,
},
}
}
}
}
}
// If a guest is trying to join a room, check that the room has a m.room.guest_access event
if req.IsGuest {
var guestAccessEvent *types.HeaderedEvent
guestAccess := "forbidden"
guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, spec.MRoomGuestAccess, "")
if (err != nil && !errors.Is(err, sql.ErrNoRows)) || guestAccessEvent == nil {
logrus.WithError(err).Warn("unable to get m.room.guest_access event, defaulting to 'forbidden'")
}
if guestAccessEvent != nil {
guestAccess = gjson.GetBytes(guestAccessEvent.Content(), "guest_access").String()
}
// Servers MUST only allow guest users to join rooms if the m.room.guest_access state event
// is present on the room and has the guest_access value can_join.
if guestAccess != "can_join" {
return nil, "", "", "", rsAPI.ErrNotAllowed{Err: fmt.Errorf("guest access is forbidden")}
}
}
// If we should do a forced federated join then do that.
if forceFederatedJoin {
joinEvent, version, serverName, federatedJoinErr := r.performFederatedMakeJoinByIDCryptoIDs(ctx, req)
return joinEvent, req.RoomIDOrAlias, version, serverName, federatedJoinErr
}
// Try to construct an actual join event from the template.
// If this succeeds then it is a sign that the room already exists
// locally on the homeserver.
// TODO: Check what happens if the room exists on the server
// but everyone has since left. I suspect it does the wrong thing.
var buildRes rsAPI.QueryLatestEventsAndStateResponse
identity := r.Cfg.Matrix.SigningIdentity
// at this point we know we have an existing room
if inRoomRes.RoomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
mapping := &gomatrixserverlib.MXIDMapping{
UserRoomKey: senderID,
UserID: userID.String(),
}
// Sign the mapping with the server identity
if err = mapping.Sign(identity.ServerName, identity.KeyID, identity.PrivateKey); err != nil {
return nil, "", "", "", err
}
req.Content["mxid_mapping"] = mapping
// sign the event with the pseudo ID key
identity = fclient.SigningIdentity{
ServerName: spec.ServerName(senderID),
KeyID: "ed25519:1",
PrivateKey: nil,
}
}
senderIDString := string(senderID)
// Prepare the template for the join event.
proto := gomatrixserverlib.ProtoEvent{
Type: spec.MRoomMember,
SenderID: senderIDString,
StateKey: &senderIDString,
RoomID: req.RoomIDOrAlias,
Redacts: "",
}
if err = proto.SetUnsigned(struct{}{}); err != nil {
return nil, "", "", "", fmt.Errorf("eb.SetUnsigned: %w", err)
}
// It is possible for the request to include some "content" for the
// event. We'll always overwrite the "membership" key, but the rest,
// like "display_name" or "avatar_url", will be kept if supplied.
if req.Content == nil {
req.Content = map[string]interface{}{}
}
req.Content["membership"] = spec.Join
if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil {
return nil, "", "", "", aerr
} else if authorisedVia != "" {
req.Content["join_authorised_via_users_server"] = authorisedVia
}
if err = proto.SetContent(req.Content); err != nil {
return nil, "", "", "", fmt.Errorf("eb.SetContent: %w", err)
}
joinEvent, err := eventutil.QueryAndBuildEvent(ctx, &proto, &identity, time.Now(), r.RSAPI, &buildRes)
switch err.(type) {
case nil:
// Do nothing
case eventutil.ErrRoomNoExists:
// The room doesn't exist locally. If the room ID looks like it should
// be ours then this probably means that we've nuked our database at
// some point.
if r.Cfg.Matrix.IsLocalServerName(roomID.Domain()) {
// If there are no more server names to try then give up here.
// Otherwise we'll try a federated join as normal, since it's quite
// possible that the room still exists on other servers.
if len(req.ServerNames) == 0 {
return nil, "", "", "", eventutil.ErrRoomNoExists{}
}
}
// Perform a federated room join.
joinEvent, version, serverName, federatedJoinErr := r.performFederatedMakeJoinByIDCryptoIDs(ctx, req)
return joinEvent, req.RoomIDOrAlias, version, serverName, federatedJoinErr
default:
// Something else went wrong.
return nil, "", "", "", fmt.Errorf("error joining local room: %q", err)
}
// By this point, if req.RoomIDOrAlias contained an alias, then
// it will have been overwritten with a room ID by performJoinRoomByAlias.
// We should now include this in the response so that the CS API can
// return the right room ID.
return joinEvent, req.RoomIDOrAlias, inRoomRes.RoomVersion, userID.Domain(), nil
}
func (r *Joiner) performFederatedMakeJoinByIDCryptoIDs(
ctx context.Context,
req *rsAPI.PerformJoinRequest,
) (gomatrixserverlib.PDU, gomatrixserverlib.RoomVersion, spec.ServerName, error) {
// Try joining by all of the supplied server names.
fedReq := fsAPI.PerformJoinRequest{
RoomID: req.RoomIDOrAlias, // the room ID to try and join
UserID: req.UserID, // the user ID joining the room
ServerNames: req.ServerNames, // the server to try joining with
Content: req.Content, // the membership event content
Unsigned: req.Unsigned, // the unsigned event content, if any
}
joinEvent, version, serverName, err := r.FSAPI.PerformMakeJoin(ctx, &fedReq)
if err != nil {
return nil, "", "", err
}
return joinEvent, version, serverName, nil
}
func (r *Joiner) performFederatedJoinRoomByID( func (r *Joiner) performFederatedJoinRoomByID(
ctx context.Context, ctx context.Context,
req *rsAPI.PerformJoinRequest, req *rsAPI.PerformJoinRequest,

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -52,9 +53,10 @@ func (r *Leaver) PerformLeave(
ctx context.Context, ctx context.Context,
req *api.PerformLeaveRequest, req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse, res *api.PerformLeaveResponse,
) ([]api.OutputEvent, error) { cryptoIDs bool,
) ([]api.OutputEvent, gomatrixserverlib.PDU, error) {
if !r.Cfg.Matrix.IsLocalServerName(req.Leaver.Domain()) { if !r.Cfg.Matrix.IsLocalServerName(req.Leaver.Domain()) {
return nil, fmt.Errorf("user %q does not belong to this homeserver", req.Leaver.String()) return nil, nil, fmt.Errorf("user %q does not belong to this homeserver", req.Leaver.String())
} }
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
"room_id": req.RoomID, "room_id": req.RoomID,
@ -62,15 +64,15 @@ func (r *Leaver) PerformLeave(
}) })
logger.Info("User requested to leave join") logger.Info("User requested to leave join")
if strings.HasPrefix(req.RoomID, "!") { if strings.HasPrefix(req.RoomID, "!") {
output, err := r.performLeaveRoomByID(context.Background(), req, res) output, event, err := r.performLeaveRoomByID(context.Background(), req, res, cryptoIDs)
if err != nil { if err != nil {
logger.WithError(err).Error("Failed to leave room") logger.WithError(err).Error("Failed to leave room")
} else { } else {
logger.Info("User left room successfully") logger.Info("User left room successfully")
} }
return output, err return output, event, err
} }
return nil, fmt.Errorf("room ID %q is invalid", req.RoomID) return nil, nil, fmt.Errorf("room ID %q is invalid", req.RoomID)
} }
// nolint:gocyclo // nolint:gocyclo
@ -78,14 +80,15 @@ func (r *Leaver) performLeaveRoomByID(
ctx context.Context, ctx context.Context,
req *api.PerformLeaveRequest, req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse, // nolint:unparam res *api.PerformLeaveResponse, // nolint:unparam
) ([]api.OutputEvent, error) { cryptoIDs bool,
) ([]api.OutputEvent, gomatrixserverlib.PDU, error) {
roomID, err := spec.NewRoomID(req.RoomID) roomID, err := spec.NewRoomID(req.RoomID)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, *roomID, req.Leaver) leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, *roomID, req.Leaver)
if err != nil || leaver == nil { if err != nil || leaver == nil {
return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String()) return nil, nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String())
} }
// If there's an invite outstanding for the room then respond to // If there's an invite outstanding for the room then respond to
@ -94,7 +97,7 @@ func (r *Leaver) performLeaveRoomByID(
if err == nil && isInvitePending { if err == nil && isInvitePending {
sender, serr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, senderUser) sender, serr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, senderUser)
if serr != nil { if serr != nil {
return nil, fmt.Errorf("failed looking up userID for sender %q: %w", senderUser, serr) return nil, nil, fmt.Errorf("failed looking up userID for sender %q: %w", senderUser, serr)
} }
var domain spec.ServerName var domain spec.ServerName
@ -107,7 +110,7 @@ func (r *Leaver) performLeaveRoomByID(
domain = sender.Domain() domain = sender.Domain()
} }
if !r.Cfg.Matrix.IsLocalServerName(domain) { if !r.Cfg.Matrix.IsLocalServerName(domain) {
return r.performFederatedRejectInvite(ctx, req, res, domain, eventID, *leaver) return r.performFederatedRejectInvite(ctx, req, res, domain, eventID, *leaver, cryptoIDs)
} }
// check that this is not a "server notice room" // check that this is not a "server notice room"
accData := &userapi.QueryAccountDataResponse{} accData := &userapi.QueryAccountDataResponse{}
@ -116,7 +119,7 @@ func (r *Leaver) performLeaveRoomByID(
RoomID: req.RoomID, RoomID: req.RoomID,
DataType: "m.tag", DataType: "m.tag",
}, accData); err != nil { }, accData); err != nil {
return nil, fmt.Errorf("unable to query account data: %w", err) return nil, nil, fmt.Errorf("unable to query account data: %w", err)
} }
if roomData, ok := accData.RoomAccountData[req.RoomID]; ok { if roomData, ok := accData.RoomAccountData[req.RoomID]; ok {
@ -124,13 +127,13 @@ func (r *Leaver) performLeaveRoomByID(
if ok { if ok {
tags := gomatrix.TagContent{} tags := gomatrix.TagContent{}
if err = json.Unmarshal(tagData, &tags); err != nil { if err = json.Unmarshal(tagData, &tags); err != nil {
return nil, fmt.Errorf("unable to unmarshal tag content") return nil, nil, fmt.Errorf("unable to unmarshal tag content")
} }
if _, ok = tags.Tags["m.server_notice"]; ok { if _, ok = tags.Tags["m.server_notice"]; ok {
// mimic the returned values from Synapse // mimic the returned values from Synapse
res.Message = "You cannot reject this invite" res.Message = "You cannot reject this invite"
res.Code = 403 res.Code = 403
return nil, spec.LeaveServerNoticeError() return nil, nil, spec.LeaveServerNoticeError()
} }
} }
} }
@ -149,22 +152,22 @@ func (r *Leaver) performLeaveRoomByID(
} }
latestRes := api.QueryLatestEventsAndStateResponse{} latestRes := api.QueryLatestEventsAndStateResponse{}
if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r.RSAPI, &latestReq, &latestRes); err != nil { if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r.RSAPI, &latestReq, &latestRes); err != nil {
return nil, err return nil, nil, err
} }
if !latestRes.RoomExists { if !latestRes.RoomExists {
return nil, fmt.Errorf("room %q does not exist", req.RoomID) return nil, nil, fmt.Errorf("room %q does not exist", req.RoomID)
} }
// Now let's see if the user is in the room. // Now let's see if the user is in the room.
if len(latestRes.StateEvents) == 0 { if len(latestRes.StateEvents) == 0 {
return nil, fmt.Errorf("user %q is not a member of room %q", req.Leaver.String(), req.RoomID) return nil, nil, fmt.Errorf("user %q is not a member of room %q", req.Leaver.String(), req.RoomID)
} }
membership, err := latestRes.StateEvents[0].Membership() membership, err := latestRes.StateEvents[0].Membership()
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting membership: %w", err) return nil, nil, fmt.Errorf("error getting membership: %w", err)
} }
if membership != spec.Join && membership != spec.Invite { if membership != spec.Join && membership != spec.Invite {
return nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.Leaver.String(), membership) return nil, nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.Leaver.String(), membership)
} }
// Prepare the template for the leave event. // Prepare the template for the leave event.
@ -177,10 +180,10 @@ func (r *Leaver) performLeaveRoomByID(
Redacts: "", Redacts: "",
} }
if err = proto.SetContent(map[string]interface{}{"membership": "leave"}); err != nil { if err = proto.SetContent(map[string]interface{}{"membership": "leave"}); err != nil {
return nil, fmt.Errorf("eb.SetContent: %w", err) return nil, nil, fmt.Errorf("eb.SetContent: %w", err)
} }
if err = proto.SetUnsigned(struct{}{}); err != nil { if err = proto.SetUnsigned(struct{}{}); err != nil {
return nil, fmt.Errorf("eb.SetUnsigned: %w", err) return nil, nil, fmt.Errorf("eb.SetUnsigned: %w", err)
} }
// We know that the user is in the room at this point so let's build // We know that the user is in the room at this point so let's build
@ -190,19 +193,29 @@ func (r *Leaver) performLeaveRoomByID(
validRoomID, err := spec.NewRoomID(req.RoomID) validRoomID, err := spec.NewRoomID(req.RoomID)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
var buildRes rsAPI.QueryLatestEventsAndStateResponse var buildRes rsAPI.QueryLatestEventsAndStateResponse
identity, err := r.RSAPI.SigningIdentityFor(ctx, *validRoomID, req.Leaver) var identity fclient.SigningIdentity
if !cryptoIDs {
identity, err = r.RSAPI.SigningIdentityFor(ctx, *validRoomID, req.Leaver)
if err != nil { if err != nil {
return nil, fmt.Errorf("SigningIdentityFor: %w", err) return nil, nil, fmt.Errorf("SigningIdentityFor: %w", err)
}
} else {
identity = fclient.SigningIdentity{
ServerName: spec.ServerName(*leaver),
KeyID: "ed25519:1",
PrivateKey: nil,
}
} }
event, err := eventutil.QueryAndBuildEvent(ctx, &proto, &identity, time.Now(), r.RSAPI, &buildRes) event, err := eventutil.QueryAndBuildEvent(ctx, &proto, &identity, time.Now(), r.RSAPI, &buildRes)
if err != nil { if err != nil {
return nil, fmt.Errorf("eventutil.QueryAndBuildEvent: %w", err) return nil, nil, fmt.Errorf("eventutil.QueryAndBuildEvent: %w", err)
} }
if !cryptoIDs {
// Give our leave event to the roomserver input stream. The // Give our leave event to the roomserver input stream. The
// roomserver will process the membership change and notify // roomserver will process the membership change and notify
// downstream automatically. // downstream automatically.
@ -219,10 +232,11 @@ func (r *Leaver) performLeaveRoomByID(
inputRes := api.InputRoomEventsResponse{} inputRes := api.InputRoomEventsResponse{}
r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes) r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes)
if err = inputRes.Err(); err != nil { if err = inputRes.Err(); err != nil {
return nil, fmt.Errorf("r.InputRoomEvents: %w", err) return nil, nil, fmt.Errorf("r.InputRoomEvents: %w", err)
}
} }
return nil, nil return nil, event, nil
} }
func (r *Leaver) performFederatedRejectInvite( func (r *Leaver) performFederatedRejectInvite(
@ -231,7 +245,8 @@ func (r *Leaver) performFederatedRejectInvite(
res *api.PerformLeaveResponse, // nolint:unparam res *api.PerformLeaveResponse, // nolint:unparam
inviteDomain spec.ServerName, eventID string, inviteDomain spec.ServerName, eventID string,
leaver spec.SenderID, leaver spec.SenderID,
) ([]api.OutputEvent, error) { cryptoIDs bool,
) ([]api.OutputEvent, gomatrixserverlib.PDU, error) {
// Ask the federation sender to perform a federated leave for us. // Ask the federation sender to perform a federated leave for us.
leaveReq := fsAPI.PerformLeaveRequest{ leaveReq := fsAPI.PerformLeaveRequest{
RoomID: req.RoomID, RoomID: req.RoomID,
@ -239,7 +254,7 @@ func (r *Leaver) performFederatedRejectInvite(
ServerNames: []spec.ServerName{inviteDomain}, ServerNames: []spec.ServerName{inviteDomain},
} }
leaveRes := fsAPI.PerformLeaveResponse{} leaveRes := fsAPI.PerformLeaveResponse{}
if err := r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil { if err := r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes, cryptoIDs); err != nil {
// failures in PerformLeave should NEVER stop us from telling other components like the // failures in PerformLeave should NEVER stop us from telling other components like the
// sync API that the invite was withdrawn. Otherwise we can end up with stuck invites. // sync API that the invite was withdrawn. Otherwise we can end up with stuck invites.
util.GetLogger(ctx).WithError(err).Errorf("failed to PerformLeave, still retiring invite event") util.GetLogger(ctx).WithError(err).Errorf("failed to PerformLeave, still retiring invite event")
@ -279,5 +294,5 @@ func (r *Leaver) performFederatedRejectInvite(
TargetSenderID: leaver, TargetSenderID: leaver,
}, },
}, },
}, nil }, nil, nil
} }

View file

@ -1044,7 +1044,7 @@ func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID,
} }
switch version { switch version {
case gomatrixserverlib.RoomVersionPseudoIDs: case gomatrixserverlib.RoomVersionPseudoIDs, gomatrixserverlib.RoomVersionCryptoIDs:
key, err := r.DB.SelectUserRoomPublicKey(ctx, userID, roomID) key, err := r.DB.SelectUserRoomPublicKey(ctx, userID, roomID)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -16,7 +16,8 @@ type RoomServer struct {
} }
func (c *RoomServer) Defaults(opts DefaultOpts) { func (c *RoomServer) Defaults(opts DefaultOpts) {
c.DefaultRoomVersion = gomatrixserverlib.RoomVersionV10 //c.DefaultRoomVersion = gomatrixserverlib.RoomVersionV10
c.DefaultRoomVersion = gomatrixserverlib.RoomVersionCryptoIDs
if opts.Generate { if opts.Generate {
if !opts.SingleDatabase { if !opts.SingleDatabase {
c.Database.ConnectionString = "file:roomserver.db" c.Database.ConnectionString = "file:roomserver.db"

View file

@ -46,6 +46,16 @@ func DeviceOTKCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID, deviceI
return nil return nil
} }
// OTCryptoIDCounts adds one-time pseudoID counts to the /sync response
func OTCryptoIDCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID string, res *types.Response) error {
count, err := keyAPI.QueryOneTimeCryptoIDs(ctx, userID)
if err != nil {
return err
}
res.OTCryptoIDsCount = count.KeyCount
return nil
}
// DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response // DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST // was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
// be already filled in with join/leave information. // be already filled in with join/leave information.

View file

@ -50,7 +50,9 @@ func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyC
} }
func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error { func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error {
return nil return nil
}
func (a *mockKeyAPI) QueryOneTimeCryptoIDs(ctx context.Context, userID string) (userapi.OneTimeCryptoIDsCount, *userapi.KeyError) {
return userapi.OneTimeCryptoIDsCount{}, nil
} }
func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.QueryDeviceMessagesRequest, res *userapi.QueryDeviceMessagesResponse) error { func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.QueryDeviceMessagesRequest, res *userapi.QueryDeviceMessagesResponse) error {
return nil return nil

View file

@ -38,7 +38,12 @@ func (p *DeviceListStreamProvider) IncrementalSync(
} }
err = internal.DeviceOTKCounts(req.Context, p.userAPI, req.Device.UserID, req.Device.ID, req.Response) err = internal.DeviceOTKCounts(req.Context, p.userAPI, req.Device.UserID, req.Device.ID, req.Response)
if err != nil { if err != nil {
req.Log.WithError(err).Error("internal.DeviceListCatchup failed") req.Log.WithError(err).Error("internal.DeviceOTKCounts failed")
return from
}
err = internal.OTCryptoIDCounts(req.Context, p.userAPI, req.Device.UserID, req.Response)
if err != nil {
req.Log.WithError(err).Error("internal.OTPseudoIDCounts failed")
return from return from
} }

View file

@ -280,6 +280,10 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
if err != nil && err != context.Canceled { if err != nil && err != context.Canceled {
syncReq.Log.WithError(err).Warn("failed to get OTK counts") syncReq.Log.WithError(err).Warn("failed to get OTK counts")
} }
err = internal.OTCryptoIDCounts(syncReq.Context, rp.userAPI, syncReq.Device.UserID, syncReq.Response)
if err != nil && err != context.Canceled {
syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts")
}
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,

View file

@ -112,6 +112,10 @@ func (s *syncUserAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOn
return nil return nil
} }
func (a *syncUserAPI) QueryOneTimeCryptoIDs(ctx context.Context, userID string) (userapi.OneTimeCryptoIDsCount, *userapi.KeyError) {
return userapi.OneTimeCryptoIDsCount{}, nil
}
func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error { func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error {
return nil return nil
} }

View file

@ -153,7 +153,7 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, userIDFor
// TODO: Set Signatures & Hashes fields // TODO: Set Signatures & Hashes fields
} }
if format != FormatSyncFederation && se.Version() == gomatrixserverlib.RoomVersionPseudoIDs { if format != FormatSyncFederation && (se.Version() == gomatrixserverlib.RoomVersionPseudoIDs || se.Version() == gomatrixserverlib.RoomVersionCryptoIDs) {
err := updatePseudoIDs(&ce, se, userIDForSender, format) err := updatePseudoIDs(&ce, se, userIDForSender, format)
if err != nil { if err != nil {
return nil, err return nil, err
@ -304,7 +304,7 @@ func GetUpdatedInviteRoomState(userIDForSender spec.UserIDForSender, inviteRoomS
return nil, err return nil, err
} }
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && eventFormat != FormatSyncFederation { if (event.Version() == gomatrixserverlib.RoomVersionPseudoIDs || event.Version() == gomatrixserverlib.RoomVersionCryptoIDs) && eventFormat != FormatSyncFederation {
for i, ev := range inviteStateEvents { for i, ev := range inviteStateEvents {
userID, userIDErr := userIDForSender(roomID, spec.SenderID(ev.SenderID)) userID, userIDErr := userIDForSender(roomID, spec.SenderID(ev.SenderID))
if userIDErr != nil { if userIDErr != nil {

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -365,6 +366,7 @@ type Response struct {
ToDevice *ToDeviceResponse `json:"to_device,omitempty"` ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
DeviceLists *DeviceLists `json:"device_lists,omitempty"` DeviceLists *DeviceLists `json:"device_lists,omitempty"`
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"` DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
OTCryptoIDsCount map[string]int `json:"one_time_cryptoids_count,omitempty"`
} }
func (r Response) MarshalJSON() ([]byte, error) { func (r Response) MarshalJSON() ([]byte, error) {
@ -427,6 +429,7 @@ func NewResponse() *Response {
res.DeviceLists = &DeviceLists{} res.DeviceLists = &DeviceLists{}
res.ToDevice = &ToDeviceResponse{} res.ToDevice = &ToDeviceResponse{}
res.DeviceListsOTKCount = map[string]int{} res.DeviceListsOTKCount = map[string]int{}
res.OTCryptoIDsCount = map[string]int{}
return &res return &res
} }
@ -530,6 +533,7 @@ type InviteResponse struct {
InviteState struct { InviteState struct {
Events []json.RawMessage `json:"events"` Events []json.RawMessage `json:"events"`
} `json:"invite_state"` } `json:"invite_state"`
OneTimeCryptoID string `json:"one_time_cryptoid,omitempty"`
} }
// NewInviteResponse creates an empty response with initialised arrays. // NewInviteResponse creates an empty response with initialised arrays.
@ -537,11 +541,17 @@ func NewInviteResponse(ctx context.Context, rsAPI api.QuerySenderIDAPI, event *t
res := InviteResponse{} res := InviteResponse{}
res.InviteState.Events = []json.RawMessage{} res.InviteState.Events = []json.RawMessage{}
logrus.Infof("Room version: %s", event.Version())
if event.Version() == gomatrixserverlib.RoomVersionCryptoIDs {
logrus.Infof("Setting invite cryptoID to %s", *event.PDU.StateKey())
res.OneTimeCryptoID = *event.PDU.StateKey()
}
// First see if there's invite_room_state in the unsigned key of the invite. // First see if there's invite_room_state in the unsigned key of the invite.
// If there is then unmarshal it into the response. This will contain the // If there is then unmarshal it into the response. This will contain the
// partial room state such as join rules, room name etc. // partial room state such as join rules, room name etc.
if inviteRoomState := gjson.GetBytes(event.Unsigned(), "invite_room_state"); inviteRoomState.Exists() { if inviteRoomState := gjson.GetBytes(event.Unsigned(), "invite_room_state"); inviteRoomState.Exists() {
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && eventFormat != synctypes.FormatSyncFederation { if (event.Version() == gomatrixserverlib.RoomVersionPseudoIDs || event.Version() == gomatrixserverlib.RoomVersionCryptoIDs) && eventFormat != synctypes.FormatSyncFederation {
updatedInvite, err := synctypes.GetUpdatedInviteRoomState(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { updatedInvite, err := synctypes.GetUpdatedInviteRoomState(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, inviteRoomState, event.PDU, event.RoomID(), eventFormat) }, inviteRoomState, event.PDU, event.RoomID(), eventFormat)

View file

@ -51,6 +51,7 @@ type AppserviceUserAPI interface {
type RoomserverUserAPI interface { type RoomserverUserAPI interface {
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
ClaimOneTimeCryptoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
} }
// api functions required by the media api // api functions required by the media api
@ -669,6 +670,7 @@ type UploadDeviceKeysAPI interface {
type SyncKeyAPI interface { type SyncKeyAPI interface {
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error
QueryOneTimeCryptoIDs(ctx context.Context, userID string) (OneTimeCryptoIDsCount, *KeyError)
PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
} }
@ -772,12 +774,25 @@ type OneTimeKeys struct {
KeyJSON map[string]json.RawMessage KeyJSON map[string]json.RawMessage
} }
type OneTimeCryptoIDs struct {
// The user who owns this device
UserID string
// A map of algorithm:key_id => key JSON
KeyJSON map[string]json.RawMessage
}
// Split a key in KeyJSON into algorithm and key ID // Split a key in KeyJSON into algorithm and key ID
func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) { func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
segments := strings.Split(keyIDWithAlgo, ":") segments := strings.Split(keyIDWithAlgo, ":")
return segments[0], segments[1] return segments[0], segments[1]
} }
// Split a key in KeyJSON into algorithm and key ID
func (k *OneTimeCryptoIDs) Split(keyIDWithAlgo string) (algo string, keyID string) {
segments := strings.Split(keyIDWithAlgo, ":")
return segments[0], segments[1]
}
// OneTimeKeysCount represents the counts of one-time keys for a single device // OneTimeKeysCount represents the counts of one-time keys for a single device
type OneTimeKeysCount struct { type OneTimeKeysCount struct {
// The user who owns this device // The user who owns this device
@ -792,12 +807,23 @@ type OneTimeKeysCount struct {
KeyCount map[string]int KeyCount map[string]int
} }
type OneTimeCryptoIDsCount struct {
// The user who owns this device
UserID string
// algorithm to count e.g:
// {
// "pseudoid_curve25519": 10,
// }
KeyCount map[string]int
}
// PerformUploadKeysRequest is the request to PerformUploadKeys // PerformUploadKeysRequest is the request to PerformUploadKeys
type PerformUploadKeysRequest struct { type PerformUploadKeysRequest struct {
UserID string // Required - User performing the request UserID string // Required - User performing the request
DeviceID string // Optional - Device performing the request, for fetching OTK count DeviceID string // Optional - Device performing the request, for fetching OTK count
DeviceKeys []DeviceKeys DeviceKeys []DeviceKeys
OneTimeKeys []OneTimeKeys OneTimeKeys []OneTimeKeys
OneTimeCryptoIDs []OneTimeCryptoIDs
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
// the display name for their respective device, and NOT to modify the keys. The key // the display name for their respective device, and NOT to modify the keys. The key
// itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths. // itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
@ -812,6 +838,7 @@ type PerformUploadKeysResponse struct {
// A map of user_id -> device_id -> Error for tracking failures. // A map of user_id -> device_id -> Error for tracking failures.
KeyErrors map[string]map[string]*KeyError KeyErrors map[string]map[string]*KeyError
OneTimeKeyCounts []OneTimeKeysCount OneTimeKeyCounts []OneTimeKeysCount
OneTimeCryptoIDCounts []OneTimeCryptoIDsCount
} }
// PerformDeleteKeysRequest asks the keyserver to forget about certain // PerformDeleteKeysRequest asks the keyserver to forget about certain

View file

@ -647,7 +647,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
user := "" user := ""
sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err == nil { if err == nil && sender != nil {
user = sender.String() user = sender.String()
} }
if user == mem.UserID { if user == mem.UserID {

View file

@ -17,9 +17,11 @@ package internal
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/ed25519"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"strings"
"sync" "sync"
"time" "time"
@ -55,11 +57,21 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
if len(req.OneTimeKeys) > 0 { if len(req.OneTimeKeys) > 0 {
a.uploadOneTimeKeys(ctx, req, res) a.uploadOneTimeKeys(ctx, req, res)
} }
if len(req.OneTimeCryptoIDs) > 0 {
a.uploadOneTimeCryptoIDs(ctx, req, res)
}
logrus.Infof("One time cryptoIDs count before: %v", res.OneTimeCryptoIDCounts)
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil { if err != nil {
return err return err
} }
res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks} res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
otpIDs, err := a.KeyDatabase.OneTimeCryptoIDsCount(ctx, req.UserID)
if err != nil {
return err
}
res.OneTimeCryptoIDCounts = []api.OneTimeCryptoIDsCount{*otpIDs}
logrus.Infof("One time cryptoIDs count after: %v", res.OneTimeCryptoIDCounts)
return nil return nil
} }
@ -181,6 +193,17 @@ func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOn
return nil return nil
} }
func (a *UserInternalAPI) QueryOneTimeCryptoIDs(ctx context.Context, userID string) (api.OneTimeCryptoIDsCount, *api.KeyError) {
count, err := a.KeyDatabase.OneTimeCryptoIDsCount(ctx, userID)
if err != nil {
return api.OneTimeCryptoIDsCount{}, &api.KeyError{
Err: fmt.Sprintf("Failed to query OTID counts: %s", err),
}
}
return *count, nil
}
func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error { func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false) msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false)
if err != nil { if err != nil {
@ -773,6 +796,98 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor
} }
func (a *UserInternalAPI) uploadOneTimeCryptoIDs(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
if req.UserID == "" {
res.Error = &api.KeyError{
Err: "user ID missing",
}
}
if len(req.OneTimeCryptoIDs) == 0 {
counts, err := a.KeyDatabase.OneTimeCryptoIDsCount(ctx, req.UserID)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.KeyDatabase.OneTimeCryptoIDsCount: %s", err),
}
}
if counts != nil {
logrus.Infof("Uploading one-time cryptoIDs: early result count: %v", *counts)
res.OneTimeCryptoIDCounts = append(res.OneTimeCryptoIDCounts, *counts)
}
return
}
for _, key := range req.OneTimeCryptoIDs {
// grab existing keys based on (user/algorithm/key ID)
keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
i := 0
for keyIDWithAlgo := range key.KeyJSON {
keyIDsWithAlgorithms[i] = keyIDWithAlgo
i++
}
existingKeys, err := a.KeyDatabase.ExistingOneTimeCryptoIDs(ctx, req.UserID, keyIDsWithAlgorithms)
if err != nil {
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: "failed to query existing one-time cryptoIDs: " + err.Error(),
})
continue
}
for keyIDWithAlgo := range existingKeys {
// if keys exist and the JSON doesn't match, error out as the key already exists
if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) {
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time cryptoID already exists", req.UserID, req.DeviceID, keyIDWithAlgo),
})
continue
}
}
// store one-time keys
counts, err := a.KeyDatabase.StoreOneTimeCryptoIDs(ctx, key)
if err != nil {
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s : failed to store one-time cryptoIDs: %s", req.UserID, req.DeviceID, err.Error()),
})
continue
}
// collect counts
logrus.Infof("Uploading one-time cryptoIDs: result count: %v", *counts)
res.OneTimeCryptoIDCounts = append(res.OneTimeCryptoIDCounts, *counts)
}
}
type Ed25519Key struct {
Key spec.Base64Bytes `json:"key"`
}
func (a *UserInternalAPI) ClaimOneTimeCryptoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
cryptoID, err := a.KeyDatabase.ClaimOneTimeCryptoID(ctx, userID, "ed25519")
if err != nil {
return "", err
}
logrus.Infof("Claimed one time cryptoID: %s", cryptoID)
if cryptoID != nil {
for key, value := range cryptoID.KeyJSON {
keyParts := strings.Split(key, ":")
if keyParts[0] == "ed25519" {
var key_bytes Ed25519Key
err := json.Unmarshal(value, &key_bytes)
if err != nil {
return "", err
}
length := len(key_bytes.Key)
if length != ed25519.PublicKeySize {
return "", fmt.Errorf("Invalid ed25519 public key, %d is the wrong size", length)
}
// TODO: cryptoIDs - store senderID for this user here?
return spec.SenderID(key_bytes.Key.Encode()), nil
}
}
}
return "", fmt.Errorf("failed claiming a valid one time cryptoID for this user: %s", userID.String())
}
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error { func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
// if we only want to update the display names, we can skip the checks below // if we only want to update the display names, we can skip the checks below
if onlyUpdateDisplayName { if onlyUpdateDisplayName {

View file

@ -175,6 +175,11 @@ type KeyDatabase interface {
// OneTimeKeysCount returns a count of all OTKs for this device. // OneTimeKeysCount returns a count of all OTKs for this device.
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
ExistingOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
StoreOneTimeCryptoIDs(ctx context.Context, keys api.OneTimeCryptoIDs) (*api.OneTimeCryptoIDsCount, error)
OneTimeCryptoIDsCount(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error)
ClaimOneTimeCryptoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimeCryptoIDs, error)
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error

View file

@ -0,0 +1,191 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"time"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
var oneTimeCryptoIDsSchema = `
-- Stores one-time cryptoIDs for users
CREATE TABLE IF NOT EXISTS keyserver_one_time_cryptoids (
user_id TEXT NOT NULL,
key_id TEXT NOT NULL,
algorithm TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL,
-- Clobber based on 3-uple of user/key/algorithm.
CONSTRAINT keyserver_one_time_cryptoids_unique UNIQUE (user_id, key_id, algorithm)
);
CREATE INDEX IF NOT EXISTS keyserver_one_time_cryptoids_idx ON keyserver_one_time_cryptoids (user_id);
`
const upsertCryptoIDsSQL = "" +
"INSERT INTO keyserver_one_time_cryptoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
" VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT ON CONSTRAINT keyserver_one_time_cryptoids_unique" +
" DO UPDATE SET key_json = $5"
const selectOneTimeCryptoIDsSQL = "" +
"SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_cryptoids WHERE user_id=$1 AND concat(algorithm, ':', key_id) = ANY($2);"
const selectCryptoIDsCountSQL = "" +
"SELECT algorithm, COUNT(key_id) FROM " +
" (SELECT algorithm, key_id FROM keyserver_one_time_cryptoids WHERE user_id = $1 LIMIT 100)" +
" x GROUP BY algorithm"
const deleteOneTimeCryptoIDSQL = "" +
"DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"
const selectCryptoIDByAlgorithmSQL = "" +
"SELECT key_id, key_json FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"
const deleteOneTimeCryptoIDsSQL = "" +
"DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1"
type oneTimeCryptoIDsStatements struct {
db *sql.DB
upsertCryptoIDsStmt *sql.Stmt
selectCryptoIDsStmt *sql.Stmt
selectCryptoIDsCountStmt *sql.Stmt
selectCryptoIDByAlgorithmStmt *sql.Stmt
deleteOneTimeCryptoIDStmt *sql.Stmt
deleteOneTimeCryptoIDsStmt *sql.Stmt
}
func NewPostgresOneTimeCryptoIDsTable(db *sql.DB) (tables.OneTimeCryptoIDs, error) {
s := &oneTimeCryptoIDsStatements{
db: db,
}
_, err := db.Exec(oneTimeCryptoIDsSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.upsertCryptoIDsStmt, upsertCryptoIDsSQL},
{&s.selectCryptoIDsStmt, selectOneTimeCryptoIDsSQL},
{&s.selectCryptoIDsCountStmt, selectCryptoIDsCountSQL},
{&s.selectCryptoIDByAlgorithmStmt, selectCryptoIDByAlgorithmSQL},
{&s.deleteOneTimeCryptoIDStmt, deleteOneTimeCryptoIDSQL},
{&s.deleteOneTimeCryptoIDsStmt, deleteOneTimeCryptoIDsSQL},
}.Prepare(db)
}
func (s *oneTimeCryptoIDsStatements) SelectOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
rows, err := s.selectCryptoIDsStmt.QueryContext(ctx, userID, pq.Array(keyIDsWithAlgorithms))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsStmt: rows.close() failed")
result := make(map[string]json.RawMessage)
var (
algorithmWithID string
keyJSONStr string
)
for rows.Next() {
if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil {
return nil, err
}
result[algorithmWithID] = json.RawMessage(keyJSONStr)
}
return result, rows.Err()
}
func (s *oneTimeCryptoIDsStatements) CountOneTimeCryptoIDs(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) {
counts := &api.OneTimeCryptoIDsCount{
UserID: userID,
KeyCount: make(map[string]int),
}
rows, err := s.selectCryptoIDsCountStmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
for rows.Next() {
var algorithm string
var count int
if err = rows.Scan(&algorithm, &count); err != nil {
return nil, err
}
counts.KeyCount[algorithm] = count
}
return counts, nil
}
func (s *oneTimeCryptoIDsStatements) InsertOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimeCryptoIDs) (*api.OneTimeCryptoIDsCount, error) {
now := time.Now().Unix()
counts := &api.OneTimeCryptoIDsCount{
UserID: keys.UserID,
KeyCount: make(map[string]int),
}
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo)
_, err := sqlutil.TxStmt(txn, s.upsertCryptoIDsStmt).ExecContext(
ctx, keys.UserID, keyID, algo, now, string(keyJSON),
)
if err != nil {
return nil, err
}
}
rows, err := sqlutil.TxStmt(txn, s.selectCryptoIDsCountStmt).QueryContext(ctx, keys.UserID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
for rows.Next() {
var algorithm string
var count int
if err = rows.Scan(&algorithm, &count); err != nil {
return nil, err
}
counts.KeyCount[algorithm] = count
}
return counts, rows.Err()
}
func (s *oneTimeCryptoIDsStatements) SelectAndDeleteOneTimeCryptoID(
ctx context.Context, txn *sql.Tx, userID, algorithm string,
) (map[string]json.RawMessage, error) {
var keyID string
var keyJSON string
err := sqlutil.TxStmtContext(ctx, txn, s.selectCryptoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeCryptoIDStmt).ExecContext(ctx, userID, algorithm, keyID)
return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err
}
func (s *oneTimeCryptoIDsStatements) DeleteOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, userID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteOneTimeCryptoIDsStmt).ExecContext(ctx, userID)
return err
}

View file

@ -149,6 +149,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
if err != nil { if err != nil {
return nil, err return nil, err
} }
otpid, err := NewPostgresOneTimeCryptoIDsTable(db)
if err != nil {
return nil, err
}
dk, err := NewPostgresDeviceKeysTable(db) dk, err := NewPostgresDeviceKeysTable(db)
if err != nil { if err != nil {
return nil, err return nil, err
@ -172,6 +176,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
return &shared.KeyDatabase{ return &shared.KeyDatabase{
OneTimeKeysTable: otk, OneTimeKeysTable: otk,
OneTimeCryptoIDsTable: otpid,
DeviceKeysTable: dk, DeviceKeysTable: dk,
KeyChangesTable: kc, KeyChangesTable: kc,
StaleDeviceListsTable: sdl, StaleDeviceListsTable: sdl,

View file

@ -65,6 +65,7 @@ type Database struct {
type KeyDatabase struct { type KeyDatabase struct {
OneTimeKeysTable tables.OneTimeKeys OneTimeKeysTable tables.OneTimeKeys
OneTimeCryptoIDsTable tables.OneTimeCryptoIDs
DeviceKeysTable tables.DeviceKeys DeviceKeysTable tables.DeviceKeys
KeyChangesTable tables.KeyChanges KeyChangesTable tables.KeyChanges
StaleDeviceListsTable tables.StaleDeviceLists StaleDeviceListsTable tables.StaleDeviceLists
@ -945,6 +946,40 @@ func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID str
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
} }
func (d *KeyDatabase) ExistingOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
return d.OneTimeCryptoIDsTable.SelectOneTimeCryptoIDs(ctx, userID, keyIDsWithAlgorithms)
}
func (d *KeyDatabase) StoreOneTimeCryptoIDs(ctx context.Context, keys api.OneTimeCryptoIDs) (counts *api.OneTimeCryptoIDsCount, err error) {
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
counts, err = d.OneTimeCryptoIDsTable.InsertOneTimeCryptoIDs(ctx, txn, keys)
return err
})
return
}
func (d *KeyDatabase) OneTimeCryptoIDsCount(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) {
return d.OneTimeCryptoIDsTable.CountOneTimeCryptoIDs(ctx, userID)
}
func (d *KeyDatabase) ClaimOneTimeCryptoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimeCryptoIDs, error) {
var result *api.OneTimeCryptoIDs
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
keyJSON, err := d.OneTimeCryptoIDsTable.SelectAndDeleteOneTimeCryptoID(ctx, txn, userID.String(), algorithm)
if err != nil {
return err
}
if keyJSON != nil {
result = &api.OneTimeCryptoIDs{
UserID: userID.String(),
KeyJSON: keyJSON,
}
}
return nil
})
return result, err
}
func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
} }

View file

@ -0,0 +1,208 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/sirupsen/logrus"
)
var oneTimeCryptoIDsSchema = `
-- Stores one-time cryptoIDs for users
CREATE TABLE IF NOT EXISTS keyserver_one_time_cryptoids (
user_id TEXT NOT NULL,
key_id TEXT NOT NULL,
algorithm TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL,
-- Clobber based on 3-uple of user/key/algorithm.
UNIQUE (user_id, key_id, algorithm)
);
CREATE INDEX IF NOT EXISTS keyserver_one_time_cryptoids_idx ON keyserver_one_time_cryptoids (user_id);
`
const upsertCryptoIDsSQL = "" +
"INSERT INTO keyserver_one_time_cryptoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
" VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT (user_id, key_id, algorithm)" +
" DO UPDATE SET key_json = $5"
const selectOneTimeCryptoIDsSQL = "" +
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_cryptoids WHERE user_id=$1"
const selectCryptoIDsCountSQL = "" +
"SELECT algorithm, COUNT(key_id) FROM " +
" (SELECT algorithm, key_id FROM keyserver_one_time_cryptoids WHERE user_id = $1 LIMIT 100)" +
" x GROUP BY algorithm"
const deleteOneTimeCryptoIDSQL = "" +
"DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"
const selectCryptoIDByAlgorithmSQL = "" +
"SELECT key_id, key_json FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"
const deleteOneTimeCryptoIDsSQL = "" +
"DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1"
type oneTimeCryptoIDsStatements struct {
db *sql.DB
upsertCryptoIDsStmt *sql.Stmt
selectCryptoIDsStmt *sql.Stmt
selectCryptoIDsCountStmt *sql.Stmt
selectCryptoIDByAlgorithmStmt *sql.Stmt
deleteOneTimeCryptoIDStmt *sql.Stmt
deleteOneTimeCryptoIDsStmt *sql.Stmt
}
func NewSqliteOneTimeCryptoIDsTable(db *sql.DB) (tables.OneTimeCryptoIDs, error) {
s := &oneTimeCryptoIDsStatements{
db: db,
}
_, err := db.Exec(oneTimeCryptoIDsSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.upsertCryptoIDsStmt, upsertCryptoIDsSQL},
{&s.selectCryptoIDsStmt, selectOneTimeCryptoIDsSQL},
{&s.selectCryptoIDsCountStmt, selectCryptoIDsCountSQL},
{&s.selectCryptoIDByAlgorithmStmt, selectCryptoIDByAlgorithmSQL},
{&s.deleteOneTimeCryptoIDStmt, deleteOneTimeCryptoIDSQL},
{&s.deleteOneTimeCryptoIDsStmt, deleteOneTimeCryptoIDsSQL},
}.Prepare(db)
}
func (s *oneTimeCryptoIDsStatements) SelectOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
rows, err := s.selectCryptoIDsStmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsStmt: rows.close() failed")
wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
for _, ka := range keyIDsWithAlgorithms {
wantSet[ka] = true
}
result := make(map[string]json.RawMessage)
for rows.Next() {
var keyID string
var algorithm string
var keyJSONStr string
if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil {
return nil, err
}
keyIDWithAlgo := algorithm + ":" + keyID
if wantSet[keyIDWithAlgo] {
result[keyIDWithAlgo] = json.RawMessage(keyJSONStr)
}
}
return result, rows.Err()
}
func (s *oneTimeCryptoIDsStatements) CountOneTimeCryptoIDs(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) {
counts := &api.OneTimeCryptoIDsCount{
UserID: userID,
KeyCount: make(map[string]int),
}
rows, err := s.selectCryptoIDsCountStmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
for rows.Next() {
var algorithm string
var count int
if err = rows.Scan(&algorithm, &count); err != nil {
return nil, err
}
counts.KeyCount[algorithm] = count
}
return counts, nil
}
func (s *oneTimeCryptoIDsStatements) InsertOneTimeCryptoIDs(
ctx context.Context, txn *sql.Tx, keys api.OneTimeCryptoIDs,
) (*api.OneTimeCryptoIDsCount, error) {
now := time.Now().Unix()
counts := &api.OneTimeCryptoIDsCount{
UserID: keys.UserID,
KeyCount: make(map[string]int),
}
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo)
_, err := sqlutil.TxStmt(txn, s.upsertCryptoIDsStmt).ExecContext(
ctx, keys.UserID, keyID, algo, now, string(keyJSON),
)
if err != nil {
return nil, err
}
}
rows, err := sqlutil.TxStmt(txn, s.selectCryptoIDsCountStmt).QueryContext(ctx, keys.UserID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
for rows.Next() {
var algorithm string
var count int
if err = rows.Scan(&algorithm, &count); err != nil {
return nil, err
}
counts.KeyCount[algorithm] = count
}
return counts, rows.Err()
}
func (s *oneTimeCryptoIDsStatements) SelectAndDeleteOneTimeCryptoID(
ctx context.Context, txn *sql.Tx, userID, algorithm string,
) (map[string]json.RawMessage, error) {
var keyID string
var keyJSON string
err := sqlutil.TxStmtContext(ctx, txn, s.selectCryptoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
if err != nil {
if err == sql.ErrNoRows {
logrus.Warnf("No rows found for one time cryptoIDs")
return nil, nil
}
return nil, err
}
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeCryptoIDStmt).ExecContext(ctx, userID, algorithm, keyID)
if err != nil {
return nil, err
}
if keyJSON == "" {
logrus.Warnf("Empty key JSON for one time cryptoIDs")
return nil, nil
}
return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err
}
func (s *oneTimeCryptoIDsStatements) DeleteOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, userID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteOneTimeCryptoIDsStmt).ExecContext(ctx, userID)
return err
}

View file

@ -146,6 +146,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
if err != nil { if err != nil {
return nil, err return nil, err
} }
otpid, err := NewSqliteOneTimeCryptoIDsTable(db)
if err != nil {
return nil, err
}
dk, err := NewSqliteDeviceKeysTable(db) dk, err := NewSqliteDeviceKeysTable(db)
if err != nil { if err != nil {
return nil, err return nil, err
@ -169,6 +173,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
return &shared.KeyDatabase{ return &shared.KeyDatabase{
OneTimeKeysTable: otk, OneTimeKeysTable: otk,
OneTimeCryptoIDsTable: otpid,
DeviceKeysTable: dk, DeviceKeysTable: dk,
KeyChangesTable: kc, KeyChangesTable: kc,
StaleDeviceListsTable: sdl, StaleDeviceListsTable: sdl,

View file

@ -1,6 +1,7 @@
package storage_test package storage_test
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -758,3 +759,35 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
} }
}) })
} }
func TestOneTimeCryptoIDs(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := mustCreateKeyDatabase(t, dbType)
defer clean()
userID := "@alice:localhost"
otk := api.OneTimeCryptoIDs{
UserID: userID,
KeyJSON: map[string]json.RawMessage{"pseudoid_curve25519:KEY1": []byte(`{"key":"v1"}`)},
}
// Add a one time pseudoID to the DB
_, err := db.StoreOneTimeCryptoIDs(ctx, otk)
MustNotError(t, err)
// Check the count of one time pseudoIDs is correct
count, err := db.OneTimeCryptoIDsCount(ctx, userID)
MustNotError(t, err)
if count.KeyCount["pseudoid_curve25519"] != 1 {
t.Fatalf("Expected 1 pseudoID, got %d", count.KeyCount["pseudoid_curve25519"])
}
// Check the actual pseudoid contents are correct
keysJSON, err := db.ExistingOneTimeCryptoIDs(ctx, userID, []string{"pseudoid_curve25519:KEY1"})
MustNotError(t, err)
keyJSON, err := keysJSON["pseudoid_curve25519:KEY1"].MarshalJSON()
MustNotError(t, err)
if !bytes.Equal(keyJSON, []byte(`{"key":"v1"}`)) {
t.Fatalf("Existing pseudoIDs do not match expected. Got %v", keysJSON["pseudoid_curve25519:KEY1"])
}
})
}

View file

@ -168,6 +168,14 @@ type OneTimeKeys interface {
DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
} }
type OneTimeCryptoIDs interface {
SelectOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
CountOneTimeCryptoIDs(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error)
InsertOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimeCryptoIDs) (*api.OneTimeCryptoIDsCount, error)
SelectAndDeleteOneTimeCryptoID(ctx context.Context, txn *sql.Tx, userID, algorithm string) (map[string]json.RawMessage, error)
DeleteOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, userID string) error
}
type DeviceKeys interface { type DeviceKeys interface {
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error