Compare commits
16 commits
main
...
crypto-ids
Author | SHA1 | Date | |
---|---|---|---|
2acc285f38 | |||
59753aba98 | |||
cbd547c828 | |||
1dc00f5fd2 | |||
b45e72830e | |||
3cbccb9ed7 | |||
5930a04044 | |||
227493cc5d | |||
7f7ac0f4fe | |||
b7d320f8d1 | |||
60be1391bf | |||
038103ac7f | |||
29cd14baf5 | |||
4bfcf27106 | |||
f17de49c6b | |||
a5ba533cfb |
|
@ -48,6 +48,7 @@ type createRoomRequest struct {
|
|||
RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"`
|
||||
PowerLevelContentOverride json.RawMessage `json:"power_level_content_override"`
|
||||
IsDirect bool `json:"is_direct"`
|
||||
CryptoID string `json:"cryptoid"`
|
||||
}
|
||||
|
||||
func (r createRoomRequest) Validate() *util.JSONResponse {
|
||||
|
@ -107,12 +108,27 @@ type createRoomResponse struct {
|
|||
RoomAlias string `json:"room_alias,omitempty"` // in synapse not spec
|
||||
}
|
||||
|
||||
type createRoomCryptoIDsResponse struct {
|
||||
RoomID string `json:"room_id"`
|
||||
Version string `json:"room_version"`
|
||||
PDUs []json.RawMessage `json:"pdus"`
|
||||
}
|
||||
|
||||
func ToProtoEvents(ctx context.Context, events []gomatrixserverlib.PDU, rsAPI roomserverAPI.ClientRoomserverAPI) []json.RawMessage {
|
||||
result := make([]json.RawMessage, len(events))
|
||||
for i, event := range events {
|
||||
result[i] = json.RawMessage(event.JSON())
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// CreateRoom implements /createRoom
|
||||
func CreateRoom(
|
||||
req *http.Request, device *api.Device,
|
||||
cfg *config.ClientAPI,
|
||||
profileAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
|
||||
asAPI appserviceAPI.AppServiceInternalAPI,
|
||||
cryptoIDs bool,
|
||||
) util.JSONResponse {
|
||||
var createRequest createRoomRequest
|
||||
resErr := httputil.UnmarshalJSONRequest(req, &createRequest)
|
||||
|
@ -129,10 +145,9 @@ func CreateRoom(
|
|||
JSON: spec.InvalidParam(err.Error()),
|
||||
}
|
||||
}
|
||||
return createRoom(req.Context(), createRequest, device, cfg, profileAPI, rsAPI, asAPI, evTime)
|
||||
return createRoom(req.Context(), createRequest, device, cfg, profileAPI, rsAPI, asAPI, evTime, cryptoIDs)
|
||||
}
|
||||
|
||||
// createRoom implements /createRoom
|
||||
func createRoom(
|
||||
ctx context.Context,
|
||||
createRequest createRoomRequest, device *api.Device,
|
||||
|
@ -140,6 +155,7 @@ func createRoom(
|
|||
profileAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
|
||||
asAPI appserviceAPI.AppServiceInternalAPI,
|
||||
evTime time.Time,
|
||||
cryptoIDs bool,
|
||||
) util.JSONResponse {
|
||||
userID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
|
@ -225,6 +241,7 @@ func createRoom(
|
|||
EventTime: evTime,
|
||||
}
|
||||
|
||||
if !cryptoIDs {
|
||||
roomAlias, createRes := rsAPI.PerformCreateRoom(ctx, *userID, *roomID, &req)
|
||||
if createRes != nil {
|
||||
return *createRes
|
||||
|
@ -239,4 +256,27 @@ func createRoom(
|
|||
Code: 200,
|
||||
JSON: response,
|
||||
}
|
||||
} else {
|
||||
req.SenderID = createRequest.CryptoID
|
||||
createEvents, err := rsAPI.PerformCreateRoomCryptoIDs(ctx, *userID, *roomID, &req)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("MakeCreateRoomEvents failed")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{Err: err.Error()},
|
||||
}
|
||||
}
|
||||
|
||||
response := createRoomCryptoIDsResponse{
|
||||
RoomID: roomID.String(),
|
||||
Version: string(roomVersion),
|
||||
PDUs: ToProtoEvents(ctx, createEvents, rsAPI),
|
||||
}
|
||||
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
JSON: response,
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -27,6 +27,7 @@ import (
|
|||
"github.com/matrix-org/gomatrix"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func JoinRoomByIDOrAlias(
|
||||
|
@ -144,3 +145,139 @@ func JoinRoomByIDOrAlias(
|
|||
return result
|
||||
}
|
||||
}
|
||||
|
||||
type joinRoomCryptoIDsResponse struct {
|
||||
RoomID string `json:"room_id"`
|
||||
Version string `json:"room_version"`
|
||||
ViaServer string `json:"via_server"`
|
||||
PDU json.RawMessage `json:"pdu"`
|
||||
}
|
||||
|
||||
func JoinRoomByIDOrAliasCryptoIDs(
|
||||
req *http.Request,
|
||||
device *api.Device,
|
||||
rsAPI roomserverAPI.ClientRoomserverAPI,
|
||||
profileAPI api.ClientUserAPI,
|
||||
roomIDOrAlias string,
|
||||
) util.JSONResponse {
|
||||
// Prepare to ask the roomserver to perform the room join.
|
||||
joinReq := roomserverAPI.PerformJoinRequest{
|
||||
RoomIDOrAlias: roomIDOrAlias,
|
||||
UserID: device.UserID,
|
||||
IsGuest: device.AccountType == api.AccountTypeGuest,
|
||||
Content: map[string]interface{}{},
|
||||
}
|
||||
|
||||
// Check to see if any ?server_name= query parameters were
|
||||
// given in the request.
|
||||
if serverNames, ok := req.URL.Query()["server_name"]; ok {
|
||||
for _, serverName := range serverNames {
|
||||
joinReq.ServerNames = append(
|
||||
joinReq.ServerNames,
|
||||
spec.ServerName(serverName),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// If content was provided in the request then include that
|
||||
// in the request. It'll get used as a part of the membership
|
||||
// event content.
|
||||
_ = httputil.UnmarshalJSONRequest(req, &joinReq.Content)
|
||||
|
||||
if senderid, ok := joinReq.Content["cryptoid"]; ok {
|
||||
logrus.Errorf("CryptoID: %s", senderid.(string))
|
||||
joinReq.SenderID = spec.SenderID(senderid.(string))
|
||||
} else {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.Unknown("Missing cryptoid in request body"),
|
||||
}
|
||||
}
|
||||
delete(joinReq.Content, "cryptoid")
|
||||
|
||||
// Work out our localpart for the client profile request.
|
||||
|
||||
// Request our profile content to populate the request content with.
|
||||
profile, err := profileAPI.QueryProfile(req.Context(), device.UserID)
|
||||
|
||||
switch err {
|
||||
case nil:
|
||||
joinReq.Content["displayname"] = profile.DisplayName
|
||||
joinReq.Content["avatar_url"] = profile.AvatarURL
|
||||
case appserviceAPI.ErrProfileNotExists:
|
||||
util.GetLogger(req.Context()).Error("Unable to query user profile, no profile found.")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.Unknown("Unable to query user profile, no profile found."),
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
// Ask the roomserver to perform the join.
|
||||
done := make(chan util.JSONResponse, 1)
|
||||
go func() {
|
||||
defer close(done)
|
||||
joinEvent, roomID, version, serverName, err := rsAPI.PerformJoinCryptoIDs(req.Context(), &joinReq)
|
||||
var response util.JSONResponse
|
||||
|
||||
switch e := err.(type) {
|
||||
case nil: // success case
|
||||
response = util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: joinRoomCryptoIDsResponse{
|
||||
RoomID: roomID,
|
||||
Version: string(version),
|
||||
ViaServer: string(serverName),
|
||||
PDU: json.RawMessage(joinEvent.JSON()),
|
||||
},
|
||||
}
|
||||
case roomserverAPI.ErrInvalidID:
|
||||
response = util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.Unknown(e.Error()),
|
||||
}
|
||||
case roomserverAPI.ErrNotAllowed:
|
||||
jsonErr := spec.Forbidden(e.Error())
|
||||
if device.AccountType == api.AccountTypeGuest {
|
||||
jsonErr = spec.GuestAccessForbidden(e.Error())
|
||||
}
|
||||
response = util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: jsonErr,
|
||||
}
|
||||
case *gomatrix.HTTPError: // this ensures we proxy responses over federation to the client
|
||||
response = util.JSONResponse{
|
||||
Code: e.Code,
|
||||
JSON: json.RawMessage(e.Message),
|
||||
}
|
||||
case eventutil.ErrRoomNoExists:
|
||||
response = util.JSONResponse{
|
||||
Code: http.StatusNotFound,
|
||||
JSON: spec.NotFound(e.Error()),
|
||||
}
|
||||
default:
|
||||
response = util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
done <- response
|
||||
}()
|
||||
|
||||
// Wait either for the join to finish, or for us to hit a reasonable
|
||||
// timeout, at which point we'll just return a 200 to placate clients.
|
||||
timer := time.NewTimer(time.Second * 20)
|
||||
select {
|
||||
case <-timer.C:
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusRequestTimeout,
|
||||
JSON: spec.Unknown("Failed creating join event with the remote server."),
|
||||
}
|
||||
case result := <-done:
|
||||
// Stop and drain the timer
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
|
|
@ -67,7 +67,7 @@ func TestJoinRoomByIDOrAlias(t *testing.T) {
|
|||
Preset: spec.PresetPublicChat,
|
||||
RoomAliasName: "alias",
|
||||
Invite: []string{bob.ID},
|
||||
}, aliceDev, &cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now())
|
||||
}, aliceDev, &cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now(), false)
|
||||
crResp, ok := resp.JSON.(createRoomResponse)
|
||||
if !ok {
|
||||
t.Fatalf("response is not a createRoomResponse: %+v", resp)
|
||||
|
@ -81,7 +81,7 @@ func TestJoinRoomByIDOrAlias(t *testing.T) {
|
|||
Visibility: "public",
|
||||
Preset: spec.PresetPublicChat,
|
||||
Invite: []string{charlie.ID},
|
||||
}, aliceDev, &cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now())
|
||||
}, aliceDev, &cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now(), false)
|
||||
crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse)
|
||||
if !ok {
|
||||
t.Fatalf("response is not a createRoomResponse: %+v", resp)
|
||||
|
|
|
@ -97,6 +97,92 @@ type queryKeysRequest struct {
|
|||
DeviceKeys map[string][]string `json:"device_keys"`
|
||||
}
|
||||
|
||||
type uploadKeysCryptoIDsRequest struct {
|
||||
DeviceKeys json.RawMessage `json:"device_keys"`
|
||||
OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"`
|
||||
OneTimeCryptoIDs map[string]json.RawMessage `json:"one_time_cryptoids"`
|
||||
}
|
||||
|
||||
func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse {
|
||||
var r uploadKeysCryptoIDsRequest
|
||||
resErr := httputil.UnmarshalJSONRequest(req, &r)
|
||||
if resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
|
||||
uploadReq := &api.PerformUploadKeysRequest{
|
||||
DeviceID: device.ID,
|
||||
UserID: device.UserID,
|
||||
}
|
||||
if r.DeviceKeys != nil {
|
||||
uploadReq.DeviceKeys = []api.DeviceKeys{
|
||||
{
|
||||
DeviceID: device.ID,
|
||||
UserID: device.UserID,
|
||||
KeyJSON: r.DeviceKeys,
|
||||
},
|
||||
}
|
||||
}
|
||||
if r.OneTimeKeys != nil {
|
||||
uploadReq.OneTimeKeys = []api.OneTimeKeys{
|
||||
{
|
||||
DeviceID: device.ID,
|
||||
UserID: device.UserID,
|
||||
KeyJSON: r.OneTimeKeys,
|
||||
},
|
||||
}
|
||||
}
|
||||
if r.OneTimeCryptoIDs != nil {
|
||||
uploadReq.OneTimeCryptoIDs = []api.OneTimeCryptoIDs{
|
||||
{
|
||||
UserID: device.UserID,
|
||||
KeyJSON: r.OneTimeCryptoIDs,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
util.GetLogger(req.Context()).
|
||||
WithField("device keys", r.DeviceKeys).
|
||||
WithField("one-time keys", r.OneTimeKeys).
|
||||
WithField("one-time cryptoids", r.OneTimeCryptoIDs).
|
||||
Info("Uploading keys")
|
||||
|
||||
var uploadRes api.PerformUploadKeysResponse
|
||||
if err := keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
if uploadRes.Error != nil {
|
||||
util.GetLogger(req.Context()).WithError(uploadRes.Error).Error("Failed to PerformUploadKeys")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
if len(uploadRes.KeyErrors) > 0 {
|
||||
util.GetLogger(req.Context()).WithField("key_errors", uploadRes.KeyErrors).Error("Failed to upload one or more keys")
|
||||
return util.JSONResponse{
|
||||
Code: 400,
|
||||
JSON: uploadRes.KeyErrors,
|
||||
}
|
||||
}
|
||||
|
||||
keyCount := make(map[string]int)
|
||||
if len(uploadRes.OneTimeKeyCounts) > 0 {
|
||||
keyCount = uploadRes.OneTimeKeyCounts[0].KeyCount
|
||||
}
|
||||
cryptoIDCount := make(map[string]int)
|
||||
if len(uploadRes.OneTimeCryptoIDCounts) > 0 {
|
||||
cryptoIDCount = uploadRes.OneTimeCryptoIDCounts[0].KeyCount
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
JSON: struct {
|
||||
OTKCounts interface{} `json:"one_time_key_counts"`
|
||||
OTIDCounts interface{} `json:"one_time_cryptoid_counts"`
|
||||
}{keyCount, cryptoIDCount},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *queryKeysRequest) GetTimeout() time.Duration {
|
||||
if r.Timeout == 0 {
|
||||
return 10 * time.Second
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
package routing
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
|
@ -23,11 +24,16 @@ import (
|
|||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
type leaveRoomCryptoIDsResponse struct {
|
||||
PDU json.RawMessage `json:"pdu"`
|
||||
}
|
||||
|
||||
func LeaveRoomByID(
|
||||
req *http.Request,
|
||||
device *api.Device,
|
||||
rsAPI roomserverAPI.ClientRoomserverAPI,
|
||||
roomID string,
|
||||
cryptoIDs bool,
|
||||
) util.JSONResponse {
|
||||
userID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
|
@ -45,7 +51,8 @@ func LeaveRoomByID(
|
|||
leaveRes := roomserverAPI.PerformLeaveResponse{}
|
||||
|
||||
// Ask the roomserver to perform the leave.
|
||||
if err := rsAPI.PerformLeave(req.Context(), &leaveReq, &leaveRes); err != nil {
|
||||
leaveEvent, err := rsAPI.PerformLeave(req.Context(), &leaveReq, &leaveRes, cryptoIDs)
|
||||
if err != nil {
|
||||
if leaveRes.Code != 0 {
|
||||
return util.JSONResponse{
|
||||
Code: leaveRes.Code,
|
||||
|
@ -58,8 +65,15 @@ func LeaveRoomByID(
|
|||
}
|
||||
}
|
||||
|
||||
if cryptoIDs {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: leaveRoomCryptoIDsResponse{
|
||||
PDU: json.RawMessage(leaveEvent.JSON()),
|
||||
},
|
||||
}
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: struct{}{},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ package routing
|
|||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
@ -39,10 +40,15 @@ import (
|
|||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
type membershipCryptoIDsResponse struct {
|
||||
PDU json.RawMessage `json:"pdu"`
|
||||
}
|
||||
|
||||
func SendBan(
|
||||
req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device,
|
||||
roomID string, cfg *config.ClientAPI,
|
||||
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI,
|
||||
cryptoIDs bool,
|
||||
) util.JSONResponse {
|
||||
body, evTime, reqErr := extractRequestData(req)
|
||||
if reqErr != nil {
|
||||
|
@ -95,13 +101,14 @@ func SendBan(
|
|||
}
|
||||
}
|
||||
|
||||
return sendMembership(req.Context(), profileAPI, device, roomID, spec.Ban, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI)
|
||||
return sendMembership(req.Context(), profileAPI, device, roomID, spec.Ban, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI, cryptoIDs)
|
||||
}
|
||||
|
||||
func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, device *userapi.Device,
|
||||
roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time,
|
||||
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI) util.JSONResponse {
|
||||
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, cryptoIDs bool) util.JSONResponse {
|
||||
|
||||
// TODO: cryptoIDs - what about when we don't know the senderID for a user?
|
||||
event, err := buildMembershipEvent(
|
||||
ctx, targetUserID, reason, profileAPI, device, membership,
|
||||
roomID, false, cfg, evTime, rsAPI, asAPI,
|
||||
|
@ -114,6 +121,8 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic
|
|||
}
|
||||
}
|
||||
|
||||
var json interface{}
|
||||
if !cryptoIDs {
|
||||
serverName := device.UserDomain()
|
||||
if err = roomserverAPI.SendEvents(
|
||||
ctx, rsAPI,
|
||||
|
@ -131,10 +140,14 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic
|
|||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
json = struct{}{}
|
||||
} else {
|
||||
json = membershipCryptoIDsResponse{PDU: event.JSON()}
|
||||
}
|
||||
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: struct{}{},
|
||||
JSON: json,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -142,6 +155,7 @@ func SendKick(
|
|||
req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device,
|
||||
roomID string, cfg *config.ClientAPI,
|
||||
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI,
|
||||
cryptoIDs bool,
|
||||
) util.JSONResponse {
|
||||
body, evTime, reqErr := extractRequestData(req)
|
||||
if reqErr != nil {
|
||||
|
@ -216,13 +230,14 @@ func SendKick(
|
|||
}
|
||||
}
|
||||
// TODO: should we be using SendLeave instead?
|
||||
return sendMembership(req.Context(), profileAPI, device, roomID, spec.Leave, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI)
|
||||
return sendMembership(req.Context(), profileAPI, device, roomID, spec.Leave, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI, cryptoIDs)
|
||||
}
|
||||
|
||||
func SendUnban(
|
||||
req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device,
|
||||
roomID string, cfg *config.ClientAPI,
|
||||
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI,
|
||||
cryptoIDs bool,
|
||||
) util.JSONResponse {
|
||||
body, evTime, reqErr := extractRequestData(req)
|
||||
if reqErr != nil {
|
||||
|
@ -272,13 +287,14 @@ func SendUnban(
|
|||
}
|
||||
}
|
||||
// TODO: should we be using SendLeave instead?
|
||||
return sendMembership(req.Context(), profileAPI, device, roomID, spec.Leave, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI)
|
||||
return sendMembership(req.Context(), profileAPI, device, roomID, spec.Leave, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI, cryptoIDs)
|
||||
}
|
||||
|
||||
func SendInvite(
|
||||
req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device,
|
||||
roomID string, cfg *config.ClientAPI,
|
||||
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI,
|
||||
cryptoIDs bool,
|
||||
) util.JSONResponse {
|
||||
body, evTime, reqErr := extractRequestData(req)
|
||||
if reqErr != nil {
|
||||
|
@ -323,7 +339,7 @@ func SendInvite(
|
|||
}
|
||||
|
||||
// We already received the return value, so no need to check for an error here.
|
||||
response, _ := sendInvite(req.Context(), profileAPI, device, roomID, body.UserID, body.Reason, cfg, rsAPI, asAPI, evTime)
|
||||
response, _ := sendInvite(req.Context(), profileAPI, device, roomID, body.UserID, body.Reason, cfg, rsAPI, asAPI, evTime, cryptoIDs)
|
||||
return response
|
||||
}
|
||||
|
||||
|
@ -336,6 +352,7 @@ func sendInvite(
|
|||
cfg *config.ClientAPI,
|
||||
rsAPI roomserverAPI.ClientRoomserverAPI,
|
||||
asAPI appserviceAPI.AppServiceInternalAPI, evTime time.Time,
|
||||
cryptoIDs bool,
|
||||
) (util.JSONResponse, error) {
|
||||
validRoomID, err := spec.NewRoomID(roomID)
|
||||
if err != nil {
|
||||
|
@ -372,7 +389,7 @@ func sendInvite(
|
|||
JSON: spec.InternalServerError{},
|
||||
}, err
|
||||
}
|
||||
err = rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{
|
||||
inviteEvent, err := rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{
|
||||
InviteInput: roomserverAPI.InviteInput{
|
||||
RoomID: *validRoomID,
|
||||
Inviter: *inviter,
|
||||
|
@ -387,7 +404,7 @@ func sendInvite(
|
|||
},
|
||||
InviteRoomState: nil, // ask the roomserver to draw up invite room state for us
|
||||
SendAsServer: string(device.UserDomain()),
|
||||
})
|
||||
}, cryptoIDs)
|
||||
|
||||
switch e := err.(type) {
|
||||
case roomserverAPI.ErrInvalidID:
|
||||
|
@ -410,10 +427,22 @@ func sendInvite(
|
|||
}, err
|
||||
}
|
||||
|
||||
return util.JSONResponse{
|
||||
response := util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: struct{}{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type inviteCryptoIDResponse struct {
|
||||
PDU json.RawMessage `json:"pdu"`
|
||||
}
|
||||
|
||||
if inviteEvent != nil {
|
||||
response.JSON = inviteCryptoIDResponse{
|
||||
PDU: json.RawMessage(inviteEvent.JSON()),
|
||||
}
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func buildMembershipEventDirect(
|
||||
|
|
|
@ -309,9 +309,33 @@ func Setup(
|
|||
|
||||
v3mux.Handle("/createRoom",
|
||||
httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return CreateRoom(req, device, cfg, userAPI, rsAPI, asAPI)
|
||||
return CreateRoom(req, device, cfg, userAPI, rsAPI, asAPI, false)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
unstableMux.Handle("/org.matrix.msc4080/createRoom",
|
||||
httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
logrus.Info("Processing request to /org.matrix.msc4080/createRoom")
|
||||
return CreateRoom(req, device, cfg, userAPI, rsAPI, asAPI, true)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
unstableMux.Handle("/org.matrix.msc4080/send_pdus/{txnID}",
|
||||
httputil.MakeAuthAPI("send_pdus", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
logrus.Info("Processing request to /org.matrix.msc4080/send_pdus")
|
||||
if r := rateLimits.Limit(req, device); r != nil {
|
||||
return *r
|
||||
}
|
||||
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
txnID := vars["txnID"]
|
||||
|
||||
// NOTE: when making events such as for create_room, multiple PDUs will need to be passed between the client & server.
|
||||
return SendPDUs(req, device, cfg, userAPI, rsAPI, asAPI, &txnID)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/join/{roomIDOrAlias}",
|
||||
httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req, device); r != nil {
|
||||
|
@ -334,8 +358,32 @@ func Setup(
|
|||
return resp.(util.JSONResponse)
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
unstableMux.Handle("/org.matrix.msc4080/join/{roomIDOrAlias}",
|
||||
httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
logrus.Info("Processing request to /org.matrix.msc4080/join/{roomIDOrAlias}")
|
||||
if r := rateLimits.Limit(req, device); r != nil {
|
||||
return *r
|
||||
}
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
// Only execute a join for roomIDOrAlias and UserID once. If there is a join in progress
|
||||
// it waits for it to complete and returns that result for subsequent requests.
|
||||
resp, _, _ := sf.Do(vars["roomIDOrAlias"]+device.UserID, func() (any, error) {
|
||||
return JoinRoomByIDOrAliasCryptoIDs(
|
||||
req, device, rsAPI, userAPI, vars["roomIDOrAlias"],
|
||||
), nil
|
||||
})
|
||||
// once all joins are processed, drop them from the cache. Further requests
|
||||
// will be processed as usual.
|
||||
sf.Forget(vars["roomIDOrAlias"] + device.UserID)
|
||||
return resp.(util.JSONResponse)
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
if mscCfg.Enabled("msc2753") {
|
||||
// TODO: update for cryptoIDs
|
||||
v3mux.Handle("/peek/{roomIDOrAlias}",
|
||||
httputil.MakeAuthAPI(spec.Peek, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req, device); r != nil {
|
||||
|
@ -378,6 +426,29 @@ func Setup(
|
|||
return resp.(util.JSONResponse)
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/join",
|
||||
httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/join")
|
||||
if r := rateLimits.Limit(req, device); r != nil {
|
||||
return *r
|
||||
}
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
// Only execute a join for roomID and UserID once. If there is a join in progress
|
||||
// it waits for it to complete and returns that result for subsequent requests.
|
||||
resp, _, _ := sf.Do(vars["roomID"]+device.UserID, func() (any, error) {
|
||||
return JoinRoomByIDOrAliasCryptoIDs(
|
||||
req, device, rsAPI, userAPI, vars["roomID"],
|
||||
), nil
|
||||
})
|
||||
// once all joins are processed, drop them from the cache. Further requests
|
||||
// will be processed as usual.
|
||||
sf.Forget(vars["roomID"] + device.UserID)
|
||||
return resp.(util.JSONResponse)
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/rooms/{roomID}/leave",
|
||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req, device); r != nil {
|
||||
|
@ -388,10 +459,26 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return LeaveRoomByID(
|
||||
req, device, rsAPI, vars["roomID"],
|
||||
req, device, rsAPI, vars["roomID"], false,
|
||||
)
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/leave",
|
||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/leave")
|
||||
if r := rateLimits.Limit(req, device); r != nil {
|
||||
return *r
|
||||
}
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return LeaveRoomByID(
|
||||
req, device, rsAPI, vars["roomID"], true,
|
||||
)
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
// TODO: update for cryptoIDs
|
||||
v3mux.Handle("/rooms/{roomID}/unpeek",
|
||||
httputil.MakeAuthAPI("unpeek", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
|
@ -409,7 +496,17 @@ func Setup(
|
|||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
||||
return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, false)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/ban",
|
||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/ban")
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, true)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/rooms/{roomID}/invite",
|
||||
|
@ -421,7 +518,20 @@ func Setup(
|
|||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
||||
return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, false)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/invite",
|
||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/invite")
|
||||
if r := rateLimits.Limit(req, device); r != nil {
|
||||
return *r
|
||||
}
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, true)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/rooms/{roomID}/kick",
|
||||
|
@ -430,7 +540,17 @@ func Setup(
|
|||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return SendKick(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
||||
return SendKick(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, false)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/kick",
|
||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/kick")
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return SendKick(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, true)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/rooms/{roomID}/unban",
|
||||
|
@ -439,7 +559,17 @@ func Setup(
|
|||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return SendUnban(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
||||
return SendUnban(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, false)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/unban",
|
||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/unban")
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return SendUnban(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, true)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/rooms/{roomID}/send/{eventType}",
|
||||
|
@ -451,6 +581,16 @@ func Setup(
|
|||
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil)
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/send/{eventType}",
|
||||
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/send/{eventType}")
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return SendEventCryptoIDs(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil)
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}",
|
||||
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
|
@ -462,6 +602,18 @@ func Setup(
|
|||
nil, cfg, rsAPI, transactionsCache)
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/send/{eventType}/{txnID}",
|
||||
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/send/{eventType}/{txnID}")
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
txnID := vars["txnID"]
|
||||
return SendEventCryptoIDs(req, device, vars["roomID"], vars["eventType"], &txnID,
|
||||
nil, cfg, rsAPI, transactionsCache)
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
|
@ -510,6 +662,18 @@ func Setup(
|
|||
return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil)
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/state/{eventType:[^/]+/?}",
|
||||
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/state/{eventType:[^/]+/?}")
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
emptyString := ""
|
||||
eventType := strings.TrimSuffix(vars["eventType"], "/")
|
||||
return SendEventCryptoIDs(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil)
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
|
||||
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
@ -521,6 +685,17 @@ func Setup(
|
|||
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil)
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/state/{eventType}/{stateKey}",
|
||||
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/state/{eventType}/{stateKey}")
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
stateKey := vars["stateKey"]
|
||||
return SendEventCryptoIDs(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil)
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
// Defined outside of handler to persist between calls
|
||||
// TODO: clear based on some criteria
|
||||
|
@ -559,6 +734,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
// TODO: update for cryptoIDs
|
||||
v3mux.Handle("/directory/room/{roomAlias}",
|
||||
httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
|
@ -569,6 +745,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
// TODO: update for cryptoIDs
|
||||
v3mux.Handle("/directory/room/{roomAlias}",
|
||||
httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
|
@ -636,6 +813,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
// TODO: update for cryptoIDs
|
||||
v3mux.Handle("/rooms/{roomID}/typing/{userID}",
|
||||
httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req, device); r != nil {
|
||||
|
@ -648,6 +826,7 @@ func Setup(
|
|||
return SendTyping(req, device, vars["roomID"], vars["userID"], rsAPI, syncProducer)
|
||||
}),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
// TODO: update for cryptoIDs
|
||||
v3mux.Handle("/rooms/{roomID}/redact/{eventID}",
|
||||
httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
|
@ -657,6 +836,7 @@ func Setup(
|
|||
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, nil, nil)
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
// TODO: update for cryptoIDs
|
||||
v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}",
|
||||
httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
|
@ -668,6 +848,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
// TODO: update for cryptoIDs
|
||||
v3mux.Handle("/sendToDevice/{eventType}/{txnID}",
|
||||
httputil.MakeAuthAPI("send_to_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
|
@ -1118,6 +1299,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
// TODO: update for cryptoIDs
|
||||
v3mux.Handle("/rooms/{roomID}/read_markers",
|
||||
httputil.MakeAuthAPI("rooms_read_markers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
if r := rateLimits.Limit(req, device); r != nil {
|
||||
|
@ -1144,6 +1326,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
// TODO: update for cryptoIDs
|
||||
v3mux.Handle("/rooms/{roomID}/upgrade",
|
||||
httputil.MakeAuthAPI("rooms_upgrade", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
|
@ -1472,6 +1655,12 @@ func Setup(
|
|||
return UploadKeys(req, userAPI, device)
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
unstableMux.Handle("/org.matrix.msc4080/keys/upload",
|
||||
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
logrus.Info("Processing request to /org.matrix.msc4080/keys/upload")
|
||||
return UploadKeysCryptoIDs(req, userAPI, device)
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/keys/query",
|
||||
httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return QueryKeys(req, userAPI, device)
|
||||
|
@ -1495,6 +1684,7 @@ func Setup(
|
|||
return SetReceipt(req, userAPI, syncProducer, device, vars["roomId"], vars["receiptType"], vars["eventId"])
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
// TODO: update for cryptoIDs
|
||||
v3mux.Handle("/presence/{userId}/status",
|
||||
httputil.MakeAuthAPI("set_presence", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
|
|
272
clientapi/routing/send_pdus.go
Normal file
272
clientapi/routing/send_pdus.go
Normal file
|
@ -0,0 +1,272 @@
|
|||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package routing
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
||||
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
type PDUInfo struct {
|
||||
Version string `json:"room_version"`
|
||||
ViaServer string `json:"via_server,omitempty"`
|
||||
PDU json.RawMessage `json:"pdu"`
|
||||
}
|
||||
|
||||
type sendPDUsRequest struct {
|
||||
PDUs []PDUInfo `json:"pdus"`
|
||||
}
|
||||
|
||||
// SendPDUs implements /sendPDUs
|
||||
// nolint:gocyclo
|
||||
func SendPDUs(
|
||||
req *http.Request, device *api.Device,
|
||||
cfg *config.ClientAPI,
|
||||
profileAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
|
||||
asAPI appserviceAPI.AppServiceInternalAPI,
|
||||
txnID *string,
|
||||
) util.JSONResponse {
|
||||
// TODO: cryptoIDs - should this include an "eventType"?
|
||||
// if it's a bulk send endpoint, I don't think that makes any sense since there are multiple event types
|
||||
// In that case, how do I know how to treat the events?
|
||||
// I could sort them all by roomID?
|
||||
// Then filter them down based on event type? (how do I collect groups of events such as for room creation?)
|
||||
// Possibly based on event hash tracking that I know were sent to the client?
|
||||
// For createRoom, I know what the possible list of events are, so try to find those and collect them to send to room creation.
|
||||
// Could also sort by depth... but that seems dangerous and depth may not be a field forever
|
||||
// Does it matter at all?
|
||||
// Can't I just forward all the events to the roomserver?
|
||||
// Do I need to do any specific processing on them?
|
||||
|
||||
var pdus sendPDUsRequest
|
||||
resErr := httputil.UnmarshalJSONRequest(req, &pdus)
|
||||
if resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
|
||||
userID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.InvalidParam(err.Error()),
|
||||
}
|
||||
}
|
||||
|
||||
// create a mutex for the specific user in the specific room
|
||||
// this avoids a situation where events that are received in quick succession are sent to the roomserver in a jumbled order
|
||||
// TODO: cryptoIDs - where to get roomID from?
|
||||
mutex, _ := userRoomSendMutexes.LoadOrStore("roomID"+userID.String(), &sync.Mutex{})
|
||||
mutex.(*sync.Mutex).Lock()
|
||||
defer mutex.(*sync.Mutex).Unlock()
|
||||
|
||||
inputs := make([]roomserverAPI.InputRoomEvent, 0, len(pdus.PDUs))
|
||||
for _, event := range pdus.PDUs {
|
||||
// TODO: cryptoIDs - event hash check?
|
||||
verImpl, err := gomatrixserverlib.GetRoomVersion(gomatrixserverlib.RoomVersion(event.Version))
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{Err: err.Error()},
|
||||
}
|
||||
}
|
||||
//eventJSON, err := json.Marshal(event)
|
||||
//if err != nil {
|
||||
// return util.JSONResponse{
|
||||
// Code: http.StatusInternalServerError,
|
||||
// JSON: spec.InternalServerError{Err: err.Error()},
|
||||
// }
|
||||
//}
|
||||
// TODO: cryptoIDs - how should we be converting to a PDU here?
|
||||
// if the hash matches an event we sent to the client, then the JSON should be good.
|
||||
// But how do we know how to fill out if the event is redacted if we use the trustedJSON function?
|
||||
// Also - untrusted JSON seems better - except it strips off the unsigned field?
|
||||
// Also - gmsl events don't store the `hashes` field... problem?
|
||||
|
||||
pdu, err := verImpl.NewEventFromUntrustedJSON(event.PDU)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{Err: err.Error()},
|
||||
}
|
||||
}
|
||||
|
||||
util.GetLogger(req.Context()).Infof("Processing %s event (%s): %s", pdu.Type(), pdu.EventID(), pdu.JSON())
|
||||
|
||||
// Check that the event is signed by the server sending the request.
|
||||
redacted, err := verImpl.RedactEventJSON(pdu.JSON())
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("RedactEventJSON failed")
|
||||
continue
|
||||
}
|
||||
|
||||
verifier := gomatrixserverlib.JSONVerifierSelf{}
|
||||
verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{
|
||||
ServerName: spec.ServerName(pdu.SenderID()),
|
||||
Message: redacted,
|
||||
AtTS: pdu.OriginServerTS(),
|
||||
ValidityCheckingFunc: gomatrixserverlib.StrictValiditySignatureCheck,
|
||||
}}
|
||||
verifyResults, err := verifier.VerifyJSONs(req.Context(), verifyRequests)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("keys.VerifyJSONs failed")
|
||||
continue
|
||||
}
|
||||
if verifyResults[0].Error != nil {
|
||||
util.GetLogger(req.Context()).WithError(verifyResults[0].Error).Error("Signature check failed: ")
|
||||
continue
|
||||
}
|
||||
|
||||
switch pdu.Type() {
|
||||
case spec.MRoomCreate:
|
||||
case spec.MRoomMember:
|
||||
var membership gomatrixserverlib.MemberContent
|
||||
err = json.Unmarshal(pdu.Content(), &membership)
|
||||
switch {
|
||||
case err != nil:
|
||||
util.GetLogger(req.Context()).Errorf("m.room.member event (%s) content invalid: %v", pdu.EventID(), pdu.Content())
|
||||
continue
|
||||
case membership.Membership == spec.Join:
|
||||
deviceUserID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
|
||||
}
|
||||
}
|
||||
if !cfg.Matrix.IsLocalServerName(pdu.RoomID().Domain()) {
|
||||
queryReq := roomserverAPI.QueryMembershipForUserRequest{
|
||||
RoomID: pdu.RoomID().String(),
|
||||
UserID: *deviceUserID,
|
||||
}
|
||||
var queryRes roomserverAPI.QueryMembershipForUserResponse
|
||||
if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
if !queryRes.IsInRoom {
|
||||
// This is a join event to a remote room
|
||||
// TODO: cryptoIDs - figure out how to obtain unsigned contents for outstanding federated invites
|
||||
joinReq := roomserverAPI.PerformJoinRequestCryptoIDs{
|
||||
RoomID: pdu.RoomID().String(),
|
||||
UserID: device.UserID,
|
||||
IsGuest: device.AccountType == api.AccountTypeGuest,
|
||||
ServerNames: []spec.ServerName{spec.ServerName(event.ViaServer)},
|
||||
JoinEvent: pdu,
|
||||
}
|
||||
err := rsAPI.PerformSendJoinCryptoIDs(req.Context(), &joinReq)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).Errorf("Failed processing %s event (%s): %v", pdu.Type(), pdu.EventID(), err)
|
||||
}
|
||||
continue // NOTE: don't send this event on to the roomserver
|
||||
}
|
||||
}
|
||||
case membership.Membership == spec.Invite:
|
||||
stateKey := pdu.StateKey()
|
||||
if stateKey == nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: spec.Forbidden("invalid state_key for membership event"),
|
||||
}
|
||||
}
|
||||
invitedUserID, err := rsAPI.QueryUserIDForSender(req.Context(), pdu.RoomID(), spec.SenderID(*stateKey))
|
||||
if err != nil || invitedUserID == nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusNotFound,
|
||||
JSON: spec.NotFound("cannot find userID for invite event"),
|
||||
}
|
||||
}
|
||||
if !cfg.Matrix.IsLocalServerName(spec.ServerName(invitedUserID.Domain())) {
|
||||
inviteReq := roomserverAPI.PerformInviteRequestCryptoIDs{
|
||||
RoomID: pdu.RoomID().String(),
|
||||
UserID: *invitedUserID,
|
||||
InviteEvent: pdu,
|
||||
}
|
||||
err := rsAPI.PerformSendInviteCryptoIDs(req.Context(), &inviteReq)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).Errorf("Failed processing %s event (%s): %v", pdu.Type(), pdu.EventID(), err)
|
||||
}
|
||||
continue // NOTE: don't send this event on to the roomserver
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: cryptoIDs - does it matter which order these are added?
|
||||
// yes - if the events are for room creation.
|
||||
// Could make this a client requirement? ie. events are processed based on the order they appear
|
||||
// We need to check event validity after processing each event.
|
||||
// ie. what if the client changes power levels that disallow further events they sent?
|
||||
// We should be doing this already as part of `SendInputRoomEvents`, but how should we pass this
|
||||
// failure back to the client?
|
||||
|
||||
var transactionID *roomserverAPI.TransactionID
|
||||
if txnID != nil {
|
||||
transactionID = &roomserverAPI.TransactionID{
|
||||
SessionID: device.SessionID, TransactionID: *txnID,
|
||||
}
|
||||
}
|
||||
|
||||
inputs = append(inputs, roomserverAPI.InputRoomEvent{
|
||||
Kind: roomserverAPI.KindNew,
|
||||
Event: &types.HeaderedEvent{PDU: pdu},
|
||||
Origin: userID.Domain(),
|
||||
// TODO: cryptoIDs - what to do with this field?
|
||||
// should probably generate this based on the event type being sent?
|
||||
//SendAsServer: roomserverAPI.DoNotSendToOtherServers,
|
||||
TransactionID: transactionID,
|
||||
})
|
||||
}
|
||||
|
||||
startedSubmittingEvents := time.Now()
|
||||
// send the events to the roomserver
|
||||
if err := roomserverAPI.SendInputRoomEvents(req.Context(), rsAPI, userID.Domain(), inputs, false); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{Err: err.Error()},
|
||||
}
|
||||
}
|
||||
timeToSubmitEvents := time.Since(startedSubmittingEvents)
|
||||
sendEventDuration.With(prometheus.Labels{"action": "submit"}).Observe(float64(timeToSubmitEvents.Milliseconds()))
|
||||
|
||||
// Add response to transactionsCache
|
||||
if txnID != nil {
|
||||
// TODO : cryptoIDs - fix this
|
||||
//res := util.JSONResponse{
|
||||
// Code: http.StatusOK,
|
||||
// JSON: sendEventResponse{e.EventID()},
|
||||
//}
|
||||
//txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res)
|
||||
}
|
||||
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
}
|
||||
}
|
|
@ -32,6 +32,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/syncapi/synctypes"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/fclient"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
|
@ -44,6 +45,11 @@ type sendEventResponse struct {
|
|||
EventID string `json:"event_id"`
|
||||
}
|
||||
|
||||
type sendEventResponseCryptoIDs struct {
|
||||
EventID string `json:"event_id"`
|
||||
PDU json.RawMessage `json:"pdu"`
|
||||
}
|
||||
|
||||
var (
|
||||
userRoomSendMutexes sync.Map // (roomID+userID) -> mutex. mutexes to ensure correct ordering of sendEvents
|
||||
)
|
||||
|
@ -94,7 +100,7 @@ func SendEvent(
|
|||
}
|
||||
|
||||
// Translate user ID state keys to room keys in pseudo ID rooms
|
||||
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && stateKey != nil {
|
||||
if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && stateKey != nil {
|
||||
parsedRoomID, innerErr := spec.NewRoomID(roomID)
|
||||
if innerErr != nil {
|
||||
return util.JSONResponse{
|
||||
|
@ -148,7 +154,7 @@ func SendEvent(
|
|||
}
|
||||
|
||||
// for power level events we need to replace the userID with the pseudoID
|
||||
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && eventType == spec.MRoomPowerLevels {
|
||||
if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && eventType == spec.MRoomPowerLevels {
|
||||
err = updatePowerLevels(req, r, roomID, rsAPI)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
|
@ -166,7 +172,7 @@ func SendEvent(
|
|||
}
|
||||
}
|
||||
|
||||
e, resErr := generateSendEvent(req.Context(), r, device, roomID, eventType, stateKey, rsAPI, evTime)
|
||||
e, resErr := generateSendEvent(req.Context(), r, device, roomID, eventType, stateKey, rsAPI, evTime, false)
|
||||
if resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
|
@ -262,6 +268,156 @@ func SendEvent(
|
|||
return res
|
||||
}
|
||||
|
||||
// SendEventCryptoIDs implements:
|
||||
//
|
||||
// /rooms/{roomID}/send/{eventType}
|
||||
// /rooms/{roomID}/send/{eventType}/{txnID}
|
||||
// /rooms/{roomID}/state/{eventType}/{stateKey}
|
||||
//
|
||||
// nolint: gocyclo
|
||||
func SendEventCryptoIDs(
|
||||
req *http.Request,
|
||||
device *userapi.Device,
|
||||
roomID, eventType string, txnID, stateKey *string,
|
||||
cfg *config.ClientAPI,
|
||||
rsAPI api.ClientRoomserverAPI,
|
||||
txnCache *transactions.Cache,
|
||||
) util.JSONResponse {
|
||||
roomVersion, err := rsAPI.QueryRoomVersionForRoom(req.Context(), roomID)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.UnsupportedRoomVersion(err.Error()),
|
||||
}
|
||||
}
|
||||
|
||||
if txnID != nil {
|
||||
// Try to fetch response from transactionsCache
|
||||
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok {
|
||||
return *res
|
||||
}
|
||||
}
|
||||
|
||||
// Translate user ID state keys to room keys in pseudo ID rooms
|
||||
if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && stateKey != nil {
|
||||
parsedRoomID, innerErr := spec.NewRoomID(roomID)
|
||||
if innerErr != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.InvalidParam("invalid room ID"),
|
||||
}
|
||||
}
|
||||
|
||||
newStateKey, innerErr := synctypes.FromClientStateKey(*parsedRoomID, *stateKey, func(roomID spec.RoomID, userID spec.UserID) (*spec.SenderID, error) {
|
||||
return rsAPI.QuerySenderIDForUser(req.Context(), roomID, userID)
|
||||
})
|
||||
if innerErr != nil {
|
||||
// TODO: work out better logic for failure cases (e.g. sender ID not found)
|
||||
util.GetLogger(req.Context()).WithError(innerErr).Error("synctypes.FromClientStateKey failed")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.Unknown("internal server error"),
|
||||
}
|
||||
}
|
||||
stateKey = newStateKey
|
||||
}
|
||||
|
||||
var r map[string]interface{} // must be a JSON object
|
||||
resErr := httputil.UnmarshalJSONRequest(req, &r)
|
||||
if resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
|
||||
if stateKey != nil {
|
||||
// If the existing/new state content are equal, return the existing event_id, making the request idempotent.
|
||||
if resp := stateEqual(req.Context(), rsAPI, eventType, *stateKey, roomID, r); resp != nil {
|
||||
return *resp
|
||||
}
|
||||
}
|
||||
|
||||
startedGeneratingEvent := time.Now()
|
||||
|
||||
// If we're sending a membership update, make sure to strip the authorised
|
||||
// via key if it is present, otherwise other servers won't be able to auth
|
||||
// the event if the room is set to the "restricted" join rule.
|
||||
if eventType == spec.MRoomMember {
|
||||
delete(r, "join_authorised_via_users_server")
|
||||
}
|
||||
|
||||
// for power level events we need to replace the userID with the pseudoID
|
||||
if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && eventType == spec.MRoomPowerLevels {
|
||||
err = updatePowerLevels(req, r, roomID, rsAPI)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{Err: err.Error()},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
evTime, err := httputil.ParseTSParam(req)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.InvalidParam(err.Error()),
|
||||
}
|
||||
}
|
||||
|
||||
e, resErr := generateSendEvent(req.Context(), r, device, roomID, eventType, stateKey, rsAPI, evTime, true)
|
||||
if resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
timeToGenerateEvent := time.Since(startedGeneratingEvent)
|
||||
|
||||
// validate that the aliases exists
|
||||
if eventType == spec.MRoomCanonicalAlias && stateKey != nil && *stateKey == "" {
|
||||
aliasReq := api.AliasEvent{}
|
||||
if err = json.Unmarshal(e.Content(), &aliasReq); err != nil {
|
||||
return util.ErrorResponse(fmt.Errorf("unable to parse alias event: %w", err))
|
||||
}
|
||||
if !aliasReq.Valid() {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.InvalidParam("Request contains invalid aliases."),
|
||||
}
|
||||
}
|
||||
aliasRes := &api.GetAliasesForRoomIDResponse{}
|
||||
if err = rsAPI.GetAliasesForRoomID(req.Context(), &api.GetAliasesForRoomIDRequest{RoomID: roomID}, aliasRes); err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
var found int
|
||||
requestAliases := append(aliasReq.AltAliases, aliasReq.Alias)
|
||||
for _, alias := range aliasRes.Aliases {
|
||||
for _, altAlias := range requestAliases {
|
||||
if altAlias == alias {
|
||||
found++
|
||||
}
|
||||
}
|
||||
}
|
||||
// check that we found at least the same amount of existing aliases as are in the request
|
||||
if aliasReq.Alias != "" && found < len(requestAliases) {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.BadAlias("No matching alias found."),
|
||||
}
|
||||
}
|
||||
}
|
||||
sendEventDuration.With(prometheus.Labels{"action": "build"}).Observe(float64(timeToGenerateEvent.Milliseconds()))
|
||||
|
||||
res := util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: sendEventResponseCryptoIDs{
|
||||
EventID: e.EventID(),
|
||||
PDU: json.RawMessage(e.JSON()),
|
||||
},
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func updatePowerLevels(req *http.Request, r map[string]interface{}, roomID string, rsAPI api.ClientRoomserverAPI) error {
|
||||
users, ok := r["users"]
|
||||
if !ok {
|
||||
|
@ -329,6 +485,7 @@ func generateSendEvent(
|
|||
roomID, eventType string, stateKey *string,
|
||||
rsAPI api.ClientRoomserverAPI,
|
||||
evTime time.Time,
|
||||
cryptoIDs bool,
|
||||
) (gomatrixserverlib.PDU, *util.JSONResponse) {
|
||||
// parse the incoming http request
|
||||
fullUserID, err := spec.NewUserID(device.UserID, true)
|
||||
|
@ -376,13 +533,19 @@ func generateSendEvent(
|
|||
}
|
||||
}
|
||||
|
||||
identity, err := rsAPI.SigningIdentityFor(ctx, *validRoomID, *fullUserID)
|
||||
if err != nil {
|
||||
var identity fclient.SigningIdentity
|
||||
if !cryptoIDs {
|
||||
id, idErr := rsAPI.SigningIdentityFor(ctx, *validRoomID, *fullUserID)
|
||||
if idErr != nil {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
identity = id
|
||||
} else {
|
||||
identity.ServerName = spec.ServerName(*senderID)
|
||||
}
|
||||
|
||||
var queryRes api.QueryLatestEventsAndStateResponse
|
||||
e, err := eventutil.QueryAndBuildEvent(ctx, &proto, &identity, evTime, rsAPI, &queryRes)
|
||||
|
|
|
@ -169,7 +169,7 @@ func SendServerNotice(
|
|||
PowerLevelContentOverride: pl,
|
||||
}
|
||||
|
||||
roomRes := createRoom(ctx, crReq, senderDevice, cfgClient, userAPI, rsAPI, asAPI, time.Now())
|
||||
roomRes := createRoom(ctx, crReq, senderDevice, cfgClient, userAPI, rsAPI, asAPI, time.Now(), false)
|
||||
|
||||
switch data := roomRes.JSON.(type) {
|
||||
case createRoomResponse:
|
||||
|
@ -215,7 +215,7 @@ func SendServerNotice(
|
|||
}
|
||||
if !membershipRes.IsInRoom {
|
||||
// re-invite the user
|
||||
res, err := sendInvite(ctx, userAPI, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now())
|
||||
res, err := sendInvite(ctx, userAPI, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now(), false)
|
||||
if err != nil {
|
||||
return res
|
||||
}
|
||||
|
@ -228,7 +228,7 @@ func SendServerNotice(
|
|||
"body": r.Content.Body,
|
||||
"msgtype": r.Content.MsgType,
|
||||
}
|
||||
e, resErr := generateSendEvent(ctx, request, senderDevice, roomID, "m.room.message", nil, rsAPI, time.Now())
|
||||
e, resErr := generateSendEvent(ctx, request, senderDevice, roomID, "m.room.message", nil, rsAPI, time.Now(), false)
|
||||
if resErr != nil {
|
||||
logrus.Errorf("failed to send message: %+v", resErr)
|
||||
return *resErr
|
||||
|
|
|
@ -198,6 +198,7 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
|
|||
// state to see if there is an event with that type and state key, if there
|
||||
// is then (by default) we return the content, otherwise a 404.
|
||||
// If eventFormat=true, sends the whole event else just the content.
|
||||
// nolint:gocyclo
|
||||
func OnIncomingStateTypeRequest(
|
||||
ctx context.Context, device *userapi.Device, rsAPI api.ClientRoomserverAPI,
|
||||
roomID, evType, stateKey string, eventFormat bool,
|
||||
|
@ -214,7 +215,7 @@ func OnIncomingStateTypeRequest(
|
|||
}
|
||||
|
||||
// Translate user ID state keys to room keys in pseudo ID rooms
|
||||
if roomVer == gomatrixserverlib.RoomVersionPseudoIDs {
|
||||
if roomVer == gomatrixserverlib.RoomVersionPseudoIDs || roomVer == gomatrixserverlib.RoomVersionCryptoIDs {
|
||||
parsedRoomID, err := spec.NewRoomID(roomID)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
|
|
|
@ -58,12 +58,18 @@ type RoomserverFederationAPI interface {
|
|||
PerformDirectoryLookup(ctx context.Context, request *PerformDirectoryLookupRequest, response *PerformDirectoryLookupResponse) error
|
||||
// Handle an instruction to make_join & send_join with a remote server.
|
||||
PerformJoin(ctx context.Context, request *PerformJoinRequest, response *PerformJoinResponse)
|
||||
PerformMakeJoin(ctx context.Context, request *PerformJoinRequest) (gomatrixserverlib.PDU, gomatrixserverlib.RoomVersion, spec.ServerName, error)
|
||||
PerformSendJoin(ctx context.Context, request *PerformSendJoinRequestCryptoIDs, response *PerformJoinResponse)
|
||||
// Handle an instruction to make_leave & send_leave with a remote server.
|
||||
PerformLeave(ctx context.Context, request *PerformLeaveRequest, response *PerformLeaveResponse) error
|
||||
PerformLeave(ctx context.Context, request *PerformLeaveRequest, response *PerformLeaveResponse, cryptoIDs bool) error
|
||||
// Handle sending an invite to a remote server.
|
||||
SendInvite(ctx context.Context, event gomatrixserverlib.PDU, strippedState []gomatrixserverlib.InviteStrippedState) (gomatrixserverlib.PDU, error)
|
||||
// Handle sending an invite to a remote server.
|
||||
SendInviteV3(ctx context.Context, event gomatrixserverlib.ProtoEvent, invitee spec.UserID, version gomatrixserverlib.RoomVersion, strippedState []gomatrixserverlib.InviteStrippedState) (gomatrixserverlib.PDU, error)
|
||||
// Handle sending an invite to a remote server.
|
||||
MakeInviteCryptoIDs(ctx context.Context, event gomatrixserverlib.ProtoEvent, invitee spec.UserID, version gomatrixserverlib.RoomVersion, strippedState []gomatrixserverlib.InviteStrippedState) (gomatrixserverlib.PDU, error)
|
||||
// Handle sending an invite to a remote server.
|
||||
SendInviteCryptoIDs(ctx context.Context, event gomatrixserverlib.PDU, invitee spec.UserID, version gomatrixserverlib.RoomVersion) error
|
||||
// Handle an instruction to peek a room on a remote server.
|
||||
PerformOutboundPeek(ctx context.Context, request *PerformOutboundPeekRequest, response *PerformOutboundPeekResponse) error
|
||||
// Query the server names of the joined hosts in a room.
|
||||
|
@ -168,6 +174,15 @@ type PerformJoinRequest struct {
|
|||
Unsigned map[string]interface{} `json:"unsigned"`
|
||||
}
|
||||
|
||||
type PerformSendJoinRequestCryptoIDs struct {
|
||||
RoomID string `json:"room_id"`
|
||||
UserID string `json:"user_id"`
|
||||
// The sorted list of servers to try. Servers will be tried sequentially, after de-duplication.
|
||||
ServerNames types.ServerNames `json:"server_names"`
|
||||
Unsigned map[string]interface{} `json:"unsigned"`
|
||||
Event gomatrixserverlib.PDU
|
||||
}
|
||||
|
||||
type PerformJoinResponse struct {
|
||||
JoinedVia spec.ServerName
|
||||
LastError *gomatrix.HTTPError
|
||||
|
|
|
@ -239,6 +239,319 @@ func (r *FederationInternalAPI) performJoinUsingServer(
|
|||
return nil
|
||||
}
|
||||
|
||||
// PerformMakeJoin implements api.FederationInternalAPI
|
||||
func (r *FederationInternalAPI) PerformMakeJoin(
|
||||
ctx context.Context,
|
||||
request *api.PerformJoinRequest,
|
||||
) (gomatrixserverlib.PDU, gomatrixserverlib.RoomVersion, spec.ServerName, error) {
|
||||
// Check that a join isn't already in progress for this user/room.
|
||||
j := federatedJoin{request.UserID, request.RoomID}
|
||||
if _, found := r.joins.Load(j); found {
|
||||
return nil, "", "", &gomatrix.HTTPError{
|
||||
Code: 429,
|
||||
Message: `{
|
||||
"errcode": "M_LIMIT_EXCEEDED",
|
||||
"error": "There is already a federated join to this room in progress. Please wait for it to finish."
|
||||
}`, // TODO: Why do none of our error types play nicely with each other?
|
||||
}
|
||||
}
|
||||
r.joins.Store(j, nil)
|
||||
defer r.joins.Delete(j)
|
||||
|
||||
// Deduplicate the server names we were provided but keep the ordering
|
||||
// as this encodes useful information about which servers are most likely
|
||||
// to respond.
|
||||
seenSet := make(map[spec.ServerName]bool)
|
||||
var uniqueList []spec.ServerName
|
||||
for _, srv := range request.ServerNames {
|
||||
if seenSet[srv] || r.cfg.Matrix.IsLocalServerName(srv) {
|
||||
continue
|
||||
}
|
||||
seenSet[srv] = true
|
||||
uniqueList = append(uniqueList, srv)
|
||||
}
|
||||
request.ServerNames = uniqueList
|
||||
|
||||
// Try each server that we were provided until we land on one that
|
||||
// successfully completes the make-join send-join dance.
|
||||
var lastErr error
|
||||
for _, serverName := range request.ServerNames {
|
||||
var joinEvent gomatrixserverlib.PDU
|
||||
var roomVersion gomatrixserverlib.RoomVersion
|
||||
var err error
|
||||
if joinEvent, roomVersion, _, err = r.performMakeJoinUsingServer(
|
||||
ctx,
|
||||
request.RoomID,
|
||||
request.UserID,
|
||||
request.Content,
|
||||
serverName,
|
||||
); err != nil {
|
||||
logrus.WithError(err).WithFields(logrus.Fields{
|
||||
"server_name": serverName,
|
||||
"room_id": request.RoomID,
|
||||
}).Warnf("Failed to join room through server")
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
// We're all good.
|
||||
return joinEvent, roomVersion, serverName, err
|
||||
}
|
||||
|
||||
// If we reach here then we didn't complete a join for some reason.
|
||||
var httpErr gomatrix.HTTPError
|
||||
var lastError *gomatrix.HTTPError
|
||||
if ok := errors.As(lastErr, &httpErr); ok {
|
||||
httpErr.Message = string(httpErr.Contents)
|
||||
lastError = &httpErr
|
||||
} else {
|
||||
lastError = &gomatrix.HTTPError{
|
||||
Code: 0,
|
||||
WrappedError: nil,
|
||||
Message: "Unknown HTTP error",
|
||||
}
|
||||
if lastErr != nil {
|
||||
lastError.Message = lastErr.Error()
|
||||
}
|
||||
}
|
||||
|
||||
logrus.Errorf(
|
||||
"failed to join user %q to room %q through %d server(s): last error %s",
|
||||
request.UserID, request.RoomID, len(request.ServerNames), lastError,
|
||||
)
|
||||
return nil, "", "", lastError
|
||||
}
|
||||
|
||||
func (r *FederationInternalAPI) performMakeJoinUsingServer(
|
||||
ctx context.Context,
|
||||
roomID, userID string,
|
||||
content map[string]interface{},
|
||||
serverName spec.ServerName,
|
||||
) (gomatrixserverlib.PDU, gomatrixserverlib.RoomVersion, spec.SenderID, error) {
|
||||
if !r.shouldAttemptDirectFederation(serverName) {
|
||||
return nil, "", "", fmt.Errorf("relay servers have no meaningful response for join.")
|
||||
}
|
||||
|
||||
user, err := spec.NewUserID(userID, true)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
room, err := spec.NewRoomID(roomID)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
|
||||
joinInput := gomatrixserverlib.PerformMakeJoinInput{
|
||||
UserID: user,
|
||||
RoomID: room,
|
||||
ServerName: serverName,
|
||||
Content: content,
|
||||
PrivateKey: r.cfg.Matrix.PrivateKey,
|
||||
KeyID: r.cfg.Matrix.KeyID,
|
||||
KeyRing: r.keyRing,
|
||||
GetOrCreateSenderID: func(ctx context.Context, userID spec.UserID, roomID spec.RoomID, roomVersion string) (spec.SenderID, ed25519.PrivateKey, error) {
|
||||
// assign a roomNID, otherwise we can't create a private key for the user
|
||||
_, nidErr := r.rsAPI.AssignRoomNID(ctx, roomID, gomatrixserverlib.RoomVersion(roomVersion))
|
||||
if nidErr != nil {
|
||||
return "", nil, nidErr
|
||||
}
|
||||
key, keyErr := r.rsAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID)
|
||||
if keyErr != nil {
|
||||
return "", nil, keyErr
|
||||
}
|
||||
return spec.SenderIDFromPseudoIDKey(key), key, nil
|
||||
},
|
||||
}
|
||||
joinEvent, version, senderID, joinErr := gomatrixserverlib.PerformMakeJoin(ctx, r, joinInput)
|
||||
|
||||
if joinErr != nil {
|
||||
if !joinErr.Reachable {
|
||||
r.statistics.ForServer(joinErr.ServerName).Failure()
|
||||
} else {
|
||||
r.statistics.ForServer(joinErr.ServerName).Success(statistics.SendDirect)
|
||||
}
|
||||
return nil, "", "", joinErr.Err
|
||||
}
|
||||
r.statistics.ForServer(serverName).Success(statistics.SendDirect)
|
||||
if joinEvent == nil {
|
||||
return nil, "", "", fmt.Errorf("Received nil joinEvent response from gomatrixserverlib.PerformJoin")
|
||||
}
|
||||
|
||||
return joinEvent, version, senderID, nil
|
||||
}
|
||||
|
||||
// PerformSendJoin implements api.FederationInternalAPI
|
||||
func (r *FederationInternalAPI) PerformSendJoin(
|
||||
ctx context.Context,
|
||||
request *api.PerformSendJoinRequestCryptoIDs,
|
||||
response *api.PerformJoinResponse,
|
||||
) {
|
||||
// Check that a join isn't already in progress for this user/room.
|
||||
j := federatedJoin{request.UserID, request.RoomID}
|
||||
if _, found := r.joins.Load(j); found {
|
||||
response.LastError = &gomatrix.HTTPError{
|
||||
Code: 429,
|
||||
Message: `{
|
||||
"errcode": "M_LIMIT_EXCEEDED",
|
||||
"error": "There is already a federated join to this room in progress. Please wait for it to finish."
|
||||
}`,
|
||||
}
|
||||
return
|
||||
}
|
||||
r.joins.Store(j, nil)
|
||||
defer r.joins.Delete(j)
|
||||
|
||||
// Deduplicate the server names we were provided but keep the ordering
|
||||
// as this encodes useful information about which servers are most likely
|
||||
// to respond.
|
||||
seenSet := make(map[spec.ServerName]bool)
|
||||
var uniqueList []spec.ServerName
|
||||
for _, srv := range request.ServerNames {
|
||||
if seenSet[srv] || r.cfg.Matrix.IsLocalServerName(srv) {
|
||||
continue
|
||||
}
|
||||
seenSet[srv] = true
|
||||
uniqueList = append(uniqueList, srv)
|
||||
}
|
||||
request.ServerNames = uniqueList
|
||||
|
||||
// Try each server that we were provided until we land on one that
|
||||
// successfully completes the make-join send-join dance.
|
||||
var lastErr error
|
||||
for _, serverName := range request.ServerNames {
|
||||
if err := r.performSendJoinUsingServer(
|
||||
ctx,
|
||||
request.RoomID,
|
||||
request.UserID,
|
||||
request.Unsigned,
|
||||
request.Event,
|
||||
serverName,
|
||||
); err != nil {
|
||||
logrus.WithError(err).WithFields(logrus.Fields{
|
||||
"server_name": serverName,
|
||||
"room_id": request.RoomID,
|
||||
}).Warnf("Failed to join room through server")
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
// We're all good.
|
||||
return
|
||||
}
|
||||
|
||||
// If we reach here then we didn't complete a join for some reason.
|
||||
var httpErr gomatrix.HTTPError
|
||||
var lastError *gomatrix.HTTPError
|
||||
if ok := errors.As(lastErr, &httpErr); ok {
|
||||
httpErr.Message = string(httpErr.Contents)
|
||||
lastError = &httpErr
|
||||
} else {
|
||||
lastError = &gomatrix.HTTPError{
|
||||
Code: 0,
|
||||
WrappedError: nil,
|
||||
Message: "Unknown HTTP error",
|
||||
}
|
||||
if lastErr != nil {
|
||||
lastError.Message = lastErr.Error()
|
||||
}
|
||||
}
|
||||
|
||||
logrus.Errorf(
|
||||
"failed to join user %q to room %q through %d server(s): last error %s",
|
||||
request.UserID, request.RoomID, len(request.ServerNames), lastError,
|
||||
)
|
||||
}
|
||||
|
||||
func (r *FederationInternalAPI) performSendJoinUsingServer(
|
||||
ctx context.Context,
|
||||
roomID, userID string,
|
||||
unsigned map[string]interface{},
|
||||
event gomatrixserverlib.PDU,
|
||||
serverName spec.ServerName,
|
||||
) error {
|
||||
user, err := spec.NewUserID(userID, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
room, err := spec.NewRoomID(roomID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, *room, *user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
joinInput := gomatrixserverlib.PerformSendJoinInput{
|
||||
RoomID: room,
|
||||
ServerName: serverName,
|
||||
Unsigned: unsigned,
|
||||
Origin: user.Domain(),
|
||||
SenderID: *senderID,
|
||||
KeyRing: r.keyRing,
|
||||
Event: event,
|
||||
RoomVersion: event.Version(),
|
||||
EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}),
|
||||
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
},
|
||||
StoreSenderIDFromPublicID: func(ctx context.Context, senderID spec.SenderID, userIDRaw string, roomID spec.RoomID) error {
|
||||
storeUserID, userErr := spec.NewUserID(userIDRaw, true)
|
||||
if userErr != nil {
|
||||
return userErr
|
||||
}
|
||||
return r.rsAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID)
|
||||
},
|
||||
}
|
||||
response, joinErr := gomatrixserverlib.PerformSendJoin(ctx, r, joinInput)
|
||||
if joinErr != nil {
|
||||
if !joinErr.Reachable {
|
||||
r.statistics.ForServer(joinErr.ServerName).Failure()
|
||||
} else {
|
||||
r.statistics.ForServer(joinErr.ServerName).Success(statistics.SendDirect)
|
||||
}
|
||||
return joinErr.Err
|
||||
}
|
||||
r.statistics.ForServer(serverName).Success(statistics.SendDirect)
|
||||
if response == nil {
|
||||
return fmt.Errorf("Received nil response from gomatrixserverlib.PerformSendJoin")
|
||||
}
|
||||
|
||||
// We need to immediately update our list of joined hosts for this room now as we are technically
|
||||
// joined. We must do this synchronously: we cannot rely on the roomserver output events as they
|
||||
// will happen asyncly. If we don't update this table, you can end up with bad failure modes like
|
||||
// joining a room, waiting for 200 OK then changing device keys and have those keys not be sent
|
||||
// to other servers (this was a cause of a flakey sytest "Local device key changes get to remote servers")
|
||||
// The events are trusted now as we performed auth checks above.
|
||||
joinedHosts, err := consumers.JoinedHostsFromEvents(ctx, response.StateSnapshot.GetStateEvents().TrustedEvents(response.JoinEvent.Version(), false), r.rsAPI)
|
||||
if err != nil {
|
||||
return fmt.Errorf("JoinedHostsFromEvents: failed to get joined hosts: %s", err)
|
||||
}
|
||||
|
||||
logrus.WithField("room", roomID).Infof("Joined federated room with %d hosts", len(joinedHosts))
|
||||
if _, err = r.db.UpdateRoom(context.Background(), roomID, joinedHosts, nil, true); err != nil {
|
||||
return fmt.Errorf("UpdatedRoom: failed to update room with joined hosts: %s", err)
|
||||
}
|
||||
|
||||
// TODO: Can I change this to not take respState but instead just take an opaque list of events?
|
||||
if err = roomserverAPI.SendEventWithState(
|
||||
context.Background(),
|
||||
r.rsAPI,
|
||||
user.Domain(),
|
||||
roomserverAPI.KindNew,
|
||||
response.StateSnapshot,
|
||||
&types.HeaderedEvent{PDU: response.JoinEvent},
|
||||
serverName,
|
||||
nil,
|
||||
false,
|
||||
); err != nil {
|
||||
return fmt.Errorf("roomserverAPI.SendEventWithState: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PerformOutboundPeekRequest implements api.FederationInternalAPI
|
||||
func (r *FederationInternalAPI) PerformOutboundPeek(
|
||||
ctx context.Context,
|
||||
|
@ -433,6 +746,7 @@ func (r *FederationInternalAPI) PerformLeave(
|
|||
ctx context.Context,
|
||||
request *api.PerformLeaveRequest,
|
||||
response *api.PerformLeaveResponse,
|
||||
cryptoIDs bool,
|
||||
) (err error) {
|
||||
userID, err := spec.NewUserID(request.UserID, true)
|
||||
if err != nil {
|
||||
|
@ -649,6 +963,87 @@ func (r *FederationInternalAPI) SendInviteV3(
|
|||
return inviteEvent, nil
|
||||
}
|
||||
|
||||
// MakeInviteCryptoIDs implements api.FederationInternalAPI
|
||||
func (r *FederationInternalAPI) MakeInviteCryptoIDs(ctx context.Context, event gomatrixserverlib.ProtoEvent, invitee spec.UserID, version gomatrixserverlib.RoomVersion, strippedState []gomatrixserverlib.InviteStrippedState) (gomatrixserverlib.PDU, error) {
|
||||
validRoomID, err := spec.NewRoomID(event.RoomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
verImpl, err := gomatrixserverlib.GetRoomVersion(version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inviter, err := r.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(event.SenderID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO (devon): This should be allowed via a relay. Currently only transactions
|
||||
// can be sent to relays. Would need to extend relays to handle invites.
|
||||
if !r.shouldAttemptDirectFederation(invitee.Domain()) {
|
||||
return nil, fmt.Errorf("relay servers have no meaningful response for invite.")
|
||||
}
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"user_id": invitee.String(),
|
||||
"room_id": event.RoomID,
|
||||
"room_version": version,
|
||||
"destination": invitee.Domain(),
|
||||
}).Info("Sending /send_invite")
|
||||
|
||||
inviteReq, err := fclient.NewInviteV3Request(event, version, strippedState)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gomatrixserverlib.NewInviteV3Request: %w", err)
|
||||
}
|
||||
|
||||
inviteRes, err := r.federation.MakeInviteCryptoIDs(ctx, inviter.Domain(), invitee.Domain(), inviteReq, invitee)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("r.federation.SendInviteCryptoIDs: failed to send invite: %w", err)
|
||||
}
|
||||
|
||||
inviteEvent, err := verImpl.NewEventFromUntrustedJSON(inviteRes.Event)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("r.federation.SendInviteCryptoIDs failed to decode event response: %w", err)
|
||||
}
|
||||
return inviteEvent, nil
|
||||
}
|
||||
|
||||
// SendInviteCryptoIDs implements api.FederationInternalAPI
|
||||
func (r *FederationInternalAPI) SendInviteCryptoIDs(ctx context.Context, event gomatrixserverlib.PDU, invitee spec.UserID, version gomatrixserverlib.RoomVersion) error {
|
||||
validRoomID := event.RoomID()
|
||||
|
||||
inviter, err := r.rsAPI.QueryUserIDForSender(ctx, validRoomID, event.SenderID())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO (devon): This should be allowed via a relay. Currently only transactions
|
||||
// can be sent to relays. Would need to extend relays to handle invites.
|
||||
if !r.shouldAttemptDirectFederation(invitee.Domain()) {
|
||||
return fmt.Errorf("relay servers have no meaningful response for invite.")
|
||||
}
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"user_id": invitee.String(),
|
||||
"room_id": event.RoomID,
|
||||
"room_version": version,
|
||||
"destination": invitee.Domain(),
|
||||
}).Info("Sending /make_invite")
|
||||
|
||||
inviteReq, err := fclient.NewSendInviteCryptoIDsRequest(event, version)
|
||||
if err != nil {
|
||||
return fmt.Errorf("gomatrixserverlib.NewInviteV3Request: %w", err)
|
||||
}
|
||||
|
||||
err = r.federation.SendInviteCryptoIDs(ctx, inviter.Domain(), invitee.Domain(), inviteReq, invitee)
|
||||
if err != nil {
|
||||
return fmt.Errorf("r.federation.SendInviteCryptoIDs: failed to send invite: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PerformServersAlive implements api.FederationInternalAPI
|
||||
func (r *FederationInternalAPI) PerformBroadcastEDU(
|
||||
ctx context.Context,
|
||||
|
|
14
go.mod
14
go.mod
|
@ -1,5 +1,9 @@
|
|||
module github.com/matrix-org/dendrite
|
||||
|
||||
//replace github.com/matrix-org/gomatrixserverlib => ../../gomatrixserverlib/crypto-ids/
|
||||
|
||||
//replace github.com/matrix-org/gomatrixserverlib => /src/gmsl/
|
||||
|
||||
require (
|
||||
github.com/Arceliar/ironwood v0.0.0-20221025225125-45b4281814c2
|
||||
github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979
|
||||
|
@ -22,7 +26,7 @@ require (
|
|||
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
|
||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
|
||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230926165653-79fcff283fc4
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20231219232834-bbfb4a048862
|
||||
github.com/matrix-org/pinecone v0.11.1-0.20230810010612-ea4c33717fd7
|
||||
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
|
||||
github.com/mattn/go-sqlite3 v1.14.17
|
||||
|
@ -42,12 +46,12 @@ require (
|
|||
github.com/uber/jaeger-lib v2.4.1+incompatible
|
||||
github.com/yggdrasil-network/yggdrasil-go v0.4.6
|
||||
go.uber.org/atomic v1.10.0
|
||||
golang.org/x/crypto v0.13.0
|
||||
golang.org/x/crypto v0.17.0
|
||||
golang.org/x/exp v0.0.0-20230809150735-7b3493d9a819
|
||||
golang.org/x/image v0.5.0
|
||||
golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e
|
||||
golang.org/x/sync v0.3.0
|
||||
golang.org/x/term v0.12.0
|
||||
golang.org/x/term v0.15.0
|
||||
gopkg.in/h2non/bimg.v1 v1.1.9
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
gotest.tools/v3 v3.4.0
|
||||
|
@ -124,8 +128,8 @@ require (
|
|||
go.etcd.io/bbolt v1.3.6 // indirect
|
||||
golang.org/x/mod v0.12.0 // indirect
|
||||
golang.org/x/net v0.14.0 // indirect
|
||||
golang.org/x/sys v0.12.0 // indirect
|
||||
golang.org/x/text v0.13.0 // indirect
|
||||
golang.org/x/sys v0.15.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
golang.org/x/time v0.3.0 // indirect
|
||||
golang.org/x/tools v0.12.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
|
|
20
go.sum
20
go.sum
|
@ -208,8 +208,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
|
|||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
|
||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
|
||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230926165653-79fcff283fc4 h1:UuXfC7b29RBDfMdLmggeF3opu3XuGi8bNT9SKZtZc3I=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230926165653-79fcff283fc4/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20231219232834-bbfb4a048862 h1:Kuya3qas85ZvVVkuOpemwhgvdJbLojvwvt3xyJTp1dY=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20231219232834-bbfb4a048862/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
|
||||
github.com/matrix-org/pinecone v0.11.1-0.20230810010612-ea4c33717fd7 h1:6t8kJr8i1/1I5nNttw6nn1ryQJgzVlBmSGgPiiaTdw4=
|
||||
github.com/matrix-org/pinecone v0.11.1-0.20230810010612-ea4c33717fd7/go.mod h1:ReWMS/LoVnOiRAdq9sNUC2NZnd1mZkMNB52QhpTRWjg=
|
||||
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
|
||||
|
@ -354,8 +354,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
|
|||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck=
|
||||
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
|
||||
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
|
||||
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
|
||||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
|
@ -418,19 +418,19 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
|
||||
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.12.0 h1:/ZfYdc3zq+q02Rv9vGqTeSItdzZTSNDmfTi0mBAuidU=
|
||||
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
|
||||
golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4=
|
||||
golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
|
||||
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
|
|
|
@ -85,13 +85,21 @@ func BuildEvent(
|
|||
}
|
||||
builder := verImpl.NewEventBuilderFromProtoEvent(proto)
|
||||
|
||||
event, err := builder.Build(
|
||||
var event gomatrixserverlib.PDU
|
||||
if identity.PrivateKey != nil {
|
||||
event, err = builder.Build(
|
||||
evTime, identity.ServerName, identity.KeyID,
|
||||
identity.PrivateKey,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
event, err = builder.BuildWithoutSigning(evTime, identity.ServerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &types.HeaderedEvent{PDU: event}, nil
|
||||
}
|
||||
|
|
|
@ -92,6 +92,7 @@ type UserRoomPrivateKeyCreator interface {
|
|||
// GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created.
|
||||
GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error)
|
||||
StoreUserRoomPublicKey(ctx context.Context, senderID spec.SenderID, userID spec.UserID, roomID spec.RoomID) error
|
||||
ClaimOneTimeSenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
|
||||
}
|
||||
|
||||
type InputRoomEventsAPI interface {
|
||||
|
@ -222,6 +223,7 @@ type ClientRoomserverAPI interface {
|
|||
DefaultRoomVersionAPI
|
||||
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
|
||||
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
|
||||
InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error)
|
||||
QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error)
|
||||
QueryStateAfterEvents(ctx context.Context, req *QueryStateAfterEventsRequest, res *QueryStateAfterEventsResponse) error
|
||||
// QueryKnownUsers returns a list of users that we know about from our joined rooms.
|
||||
|
@ -232,6 +234,7 @@ type ClientRoomserverAPI interface {
|
|||
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error
|
||||
GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error
|
||||
|
||||
PerformCreateRoomCryptoIDs(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) ([]gomatrixserverlib.PDU, error)
|
||||
PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse)
|
||||
// PerformRoomUpgrade upgrades a room to a newer version
|
||||
PerformRoomUpgrade(ctx context.Context, roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error)
|
||||
|
@ -241,9 +244,12 @@ type ClientRoomserverAPI interface {
|
|||
PerformAdminDownloadState(ctx context.Context, roomID, userID string, serverName spec.ServerName) error
|
||||
PerformPeek(ctx context.Context, req *PerformPeekRequest) (roomID string, err error)
|
||||
PerformUnpeek(ctx context.Context, roomID, userID, deviceID string) error
|
||||
PerformInvite(ctx context.Context, req *PerformInviteRequest) error
|
||||
PerformInvite(ctx context.Context, req *PerformInviteRequest, cryptoIDs bool) (gomatrixserverlib.PDU, error)
|
||||
PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error)
|
||||
PerformLeave(ctx context.Context, req *PerformLeaveRequest, res *PerformLeaveResponse) error
|
||||
PerformSendJoinCryptoIDs(ctx context.Context, req *PerformJoinRequestCryptoIDs) error
|
||||
PerformSendInviteCryptoIDs(ctx context.Context, req *PerformInviteRequestCryptoIDs) error
|
||||
PerformJoinCryptoIDs(ctx context.Context, req *PerformJoinRequest) (join gomatrixserverlib.PDU, roomID string, version gomatrixserverlib.RoomVersion, serverName spec.ServerName, err error)
|
||||
PerformLeave(ctx context.Context, req *PerformLeaveRequest, res *PerformLeaveResponse, cryptoIDs bool) (gomatrixserverlib.PDU, error)
|
||||
PerformPublish(ctx context.Context, req *PerformPublishRequest) error
|
||||
// PerformForget forgets a rooms history for a specific user
|
||||
PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error
|
||||
|
@ -305,7 +311,7 @@ type FederationRoomserverAPI interface {
|
|||
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
|
||||
HandleInvite(ctx context.Context, event *types.HeaderedEvent) error
|
||||
|
||||
PerformInvite(ctx context.Context, req *PerformInviteRequest) error
|
||||
PerformInvite(ctx context.Context, req *PerformInviteRequest, cryptoIDs bool) (gomatrixserverlib.PDU, error)
|
||||
// Query a given amount (or less) of events prior to a given set of events.
|
||||
PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error
|
||||
|
||||
|
|
|
@ -29,6 +29,8 @@ type PerformCreateRoomRequest struct {
|
|||
KeyID gomatrixserverlib.KeyID
|
||||
PrivateKey ed25519.PrivateKey
|
||||
EventTime time.Time
|
||||
|
||||
SenderID string
|
||||
}
|
||||
|
||||
type PerformJoinRequest struct {
|
||||
|
@ -38,6 +40,23 @@ type PerformJoinRequest struct {
|
|||
Content map[string]interface{} `json:"content"`
|
||||
ServerNames []spec.ServerName `json:"server_names"`
|
||||
Unsigned map[string]interface{} `json:"unsigned"`
|
||||
SenderID spec.SenderID
|
||||
}
|
||||
|
||||
type PerformJoinRequestCryptoIDs struct {
|
||||
RoomID string
|
||||
UserID string
|
||||
IsGuest bool
|
||||
Content map[string]interface{}
|
||||
ServerNames []spec.ServerName
|
||||
Unsigned map[string]interface{}
|
||||
JoinEvent gomatrixserverlib.PDU
|
||||
}
|
||||
|
||||
type PerformInviteRequestCryptoIDs struct {
|
||||
RoomID string
|
||||
UserID spec.UserID
|
||||
InviteEvent gomatrixserverlib.PDU
|
||||
}
|
||||
|
||||
type PerformLeaveRequest struct {
|
||||
|
|
|
@ -3,6 +3,7 @@ package internal
|
|||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"fmt"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
@ -54,6 +55,7 @@ type RoomserverInternalAPI struct {
|
|||
ServerACLs *acls.ServerACLs
|
||||
fsAPI fsAPI.RoomserverFederationAPI
|
||||
asAPI asAPI.AppServiceInternalAPI
|
||||
usAPI userapi.RoomserverUserAPI
|
||||
NATSClient *nats.Conn
|
||||
JetStream nats.JetStreamContext
|
||||
Durable string
|
||||
|
@ -214,6 +216,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
|||
func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) {
|
||||
r.Leaver.UserAPI = userAPI
|
||||
r.Inputer.UserAPI = userAPI
|
||||
r.usAPI = userAPI
|
||||
}
|
||||
|
||||
func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) {
|
||||
|
@ -251,24 +254,27 @@ func (r *RoomserverInternalAPI) PerformCreateRoom(
|
|||
func (r *RoomserverInternalAPI) PerformInvite(
|
||||
ctx context.Context,
|
||||
req *api.PerformInviteRequest,
|
||||
) error {
|
||||
return r.Inviter.PerformInvite(ctx, req)
|
||||
cryptoIDs bool,
|
||||
) (gomatrixserverlib.PDU, error) {
|
||||
return r.Inviter.PerformInvite(ctx, req, cryptoIDs)
|
||||
}
|
||||
|
||||
func (r *RoomserverInternalAPI) PerformLeave(
|
||||
ctx context.Context,
|
||||
req *api.PerformLeaveRequest,
|
||||
res *api.PerformLeaveResponse,
|
||||
) error {
|
||||
outputEvents, err := r.Leaver.PerformLeave(ctx, req, res)
|
||||
cryptoIDs bool,
|
||||
) (gomatrixserverlib.PDU, error) {
|
||||
outputEvents, leaveEvent, err := r.Leaver.PerformLeave(ctx, req, res, cryptoIDs)
|
||||
if err != nil {
|
||||
sentry.CaptureException(err)
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if len(outputEvents) == 0 {
|
||||
return nil
|
||||
return leaveEvent, nil
|
||||
}
|
||||
return r.OutputProducer.ProduceRoomEvents(req.RoomID, outputEvents)
|
||||
// TODO: cryptoIDs - what to do with this?
|
||||
return leaveEvent, r.OutputProducer.ProduceRoomEvents(req.RoomID, outputEvents)
|
||||
}
|
||||
|
||||
func (r *RoomserverInternalAPI) PerformForget(
|
||||
|
@ -308,6 +314,10 @@ func (r *RoomserverInternalAPI) StoreUserRoomPublicKey(ctx context.Context, send
|
|||
return err
|
||||
}
|
||||
|
||||
func (r *RoomserverInternalAPI) ClaimOneTimeSenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
|
||||
return r.usAPI.ClaimOneTimeCryptoID(ctx, roomID, userID)
|
||||
}
|
||||
|
||||
func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) {
|
||||
roomVersion, ok := r.Cache.GetRoomVersion(roomID.String())
|
||||
if !ok {
|
||||
|
@ -330,6 +340,19 @@ func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID s
|
|||
ServerName: spec.ServerName(spec.SenderIDFromPseudoIDKey(privKey)),
|
||||
}, nil
|
||||
}
|
||||
if roomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
|
||||
sender, err := r.QuerySenderIDForUser(ctx, roomID, senderID)
|
||||
if err != nil {
|
||||
return fclient.SigningIdentity{}, err
|
||||
} else if sender == nil {
|
||||
return fclient.SigningIdentity{}, fmt.Errorf("no sender ID for %s in %s", senderID.String(), roomID.String())
|
||||
}
|
||||
return fclient.SigningIdentity{
|
||||
PrivateKey: nil,
|
||||
KeyID: "ed25519:1",
|
||||
ServerName: spec.ServerName(*sender),
|
||||
}, nil
|
||||
}
|
||||
identity, err := r.Cfg.Global.SigningIdentityFor(senderID.Domain())
|
||||
if err != nil {
|
||||
return fclient.SigningIdentity{}, err
|
||||
|
|
|
@ -445,7 +445,7 @@ func (r *Inputer) processRoomEvent(
|
|||
}
|
||||
|
||||
// TODO: Revist this to ensure we don't replace a current state mxid_mapping with an older one.
|
||||
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && event.Type() == spec.MRoomMember {
|
||||
if (event.Version() == gomatrixserverlib.RoomVersionPseudoIDs || event.Version() == gomatrixserverlib.RoomVersionCryptoIDs) && event.Type() == spec.MRoomMember {
|
||||
mapping := gomatrixserverlib.MemberContent{}
|
||||
if err = json.Unmarshal(event.Content(), &mapping); err != nil {
|
||||
return err
|
||||
|
|
|
@ -179,7 +179,7 @@ func (r *Admin) PerformAdminEvacuateUser(
|
|||
Leaver: *fullUserID,
|
||||
}
|
||||
leaveRes := &api.PerformLeaveResponse{}
|
||||
outputEvents, err := r.Leaver.PerformLeave(ctx, leaveReq, leaveRes)
|
||||
outputEvents, _, err := r.Leaver.PerformLeave(ctx, leaveReq, leaveRes, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -44,6 +44,417 @@ type Creator struct {
|
|||
RSAPI api.RoomserverInternalAPI
|
||||
}
|
||||
|
||||
// PerformCreateRoomCryptoIDs handles all the steps necessary to create a new room.
|
||||
// nolint: gocyclo
|
||||
func (c *Creator) PerformCreateRoomCryptoIDs(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *api.PerformCreateRoomRequest) ([]gomatrixserverlib.PDU, error) {
|
||||
verImpl, err := gomatrixserverlib.GetRoomVersion(createRequest.RoomVersion)
|
||||
if err != nil {
|
||||
return nil, spec.BadJSON("unknown room version")
|
||||
}
|
||||
|
||||
createContent := map[string]interface{}{}
|
||||
if len(createRequest.CreationContent) > 0 {
|
||||
if err = json.Unmarshal(createRequest.CreationContent, &createContent); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for creation_content failed")
|
||||
return nil, spec.BadJSON("invalid create content")
|
||||
}
|
||||
}
|
||||
|
||||
senderID := spec.SenderID(createRequest.SenderID)
|
||||
|
||||
// TODO: cryptoIDs - should we be assigning a room NID yet?
|
||||
_, err = c.DB.AssignRoomNID(ctx, roomID, createRequest.RoomVersion)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("failed to assign roomNID")
|
||||
return nil, spec.InternalServerError{Err: err.Error()}
|
||||
}
|
||||
|
||||
if createRequest.RoomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
|
||||
util.GetLogger(ctx).Infof("StoreUserRoomPublicKey - SenderID: %s UserID: %s RoomID: %s", senderID, userID.String(), roomID.String())
|
||||
bytes := spec.Base64Bytes{}
|
||||
err = bytes.Decode(string(senderID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(bytes) != ed25519.PublicKeySize {
|
||||
return nil, spec.BadJSON("SenderID is not a valid ed25519 public key")
|
||||
}
|
||||
|
||||
keyErr := c.RSAPI.StoreUserRoomPublicKey(ctx, senderID, userID, roomID)
|
||||
if keyErr != nil {
|
||||
util.GetLogger(ctx).WithError(keyErr).Error("StoreUserRoomPublicKey failed")
|
||||
return nil, spec.InternalServerError{Err: keyErr.Error()}
|
||||
}
|
||||
}
|
||||
|
||||
createContent["creator"] = senderID
|
||||
createContent["room_version"] = createRequest.RoomVersion
|
||||
powerLevelContent := eventutil.InitialPowerLevelsContent(string(senderID))
|
||||
joinRuleContent := gomatrixserverlib.JoinRuleContent{
|
||||
JoinRule: spec.Invite,
|
||||
}
|
||||
historyVisibilityContent := gomatrixserverlib.HistoryVisibilityContent{
|
||||
HistoryVisibility: historyVisibilityShared,
|
||||
}
|
||||
|
||||
if createRequest.PowerLevelContentOverride != nil {
|
||||
// Merge powerLevelContentOverride fields by unmarshalling it atop the defaults
|
||||
err = json.Unmarshal(createRequest.PowerLevelContentOverride, &powerLevelContent)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for power_level_content_override failed")
|
||||
return nil, spec.BadJSON("malformed power_level_content_override")
|
||||
}
|
||||
}
|
||||
|
||||
var guestsCanJoin bool
|
||||
switch createRequest.StatePreset {
|
||||
case spec.PresetPrivateChat:
|
||||
joinRuleContent.JoinRule = spec.Invite
|
||||
historyVisibilityContent.HistoryVisibility = historyVisibilityShared
|
||||
guestsCanJoin = true
|
||||
case spec.PresetTrustedPrivateChat:
|
||||
joinRuleContent.JoinRule = spec.Invite
|
||||
historyVisibilityContent.HistoryVisibility = historyVisibilityShared
|
||||
for _, invitee := range createRequest.InvitedUsers {
|
||||
powerLevelContent.Users[invitee] = 100
|
||||
}
|
||||
guestsCanJoin = true
|
||||
case spec.PresetPublicChat:
|
||||
joinRuleContent.JoinRule = spec.Public
|
||||
historyVisibilityContent.HistoryVisibility = historyVisibilityShared
|
||||
}
|
||||
|
||||
createEvent := gomatrixserverlib.FledglingEvent{
|
||||
Type: spec.MRoomCreate,
|
||||
Content: createContent,
|
||||
}
|
||||
powerLevelEvent := gomatrixserverlib.FledglingEvent{
|
||||
Type: spec.MRoomPowerLevels,
|
||||
Content: powerLevelContent,
|
||||
}
|
||||
joinRuleEvent := gomatrixserverlib.FledglingEvent{
|
||||
Type: spec.MRoomJoinRules,
|
||||
Content: joinRuleContent,
|
||||
}
|
||||
historyVisibilityEvent := gomatrixserverlib.FledglingEvent{
|
||||
Type: spec.MRoomHistoryVisibility,
|
||||
Content: historyVisibilityContent,
|
||||
}
|
||||
membershipEvent := gomatrixserverlib.FledglingEvent{
|
||||
Type: spec.MRoomMember,
|
||||
StateKey: string(senderID),
|
||||
}
|
||||
|
||||
memberContent := gomatrixserverlib.MemberContent{
|
||||
Membership: spec.Join,
|
||||
DisplayName: createRequest.UserDisplayName,
|
||||
AvatarURL: createRequest.UserAvatarURL,
|
||||
}
|
||||
|
||||
// If we are creating a room with pseudo IDs, create and sign the MXIDMapping
|
||||
if createRequest.RoomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
|
||||
mapping := &gomatrixserverlib.MXIDMapping{
|
||||
UserRoomKey: senderID,
|
||||
UserID: userID.String(),
|
||||
}
|
||||
|
||||
identity, idErr := c.Cfg.Matrix.SigningIdentityFor(userID.Domain()) // we MUST use the server signing mxid_mapping
|
||||
if idErr != nil {
|
||||
logrus.WithError(idErr).WithField("domain", userID.Domain()).Error("unable to find signing identity for domain")
|
||||
return nil, spec.InternalServerError{Err: idErr.Error()}
|
||||
}
|
||||
|
||||
// Sign the mapping with the server identity
|
||||
if err = mapping.Sign(identity.ServerName, identity.KeyID, identity.PrivateKey); err != nil {
|
||||
return nil, spec.InternalServerError{Err: err.Error()}
|
||||
}
|
||||
memberContent.MXIDMapping = mapping
|
||||
}
|
||||
membershipEvent.Content = memberContent
|
||||
|
||||
var nameEvent *gomatrixserverlib.FledglingEvent
|
||||
var topicEvent *gomatrixserverlib.FledglingEvent
|
||||
var guestAccessEvent *gomatrixserverlib.FledglingEvent
|
||||
var aliasEvent *gomatrixserverlib.FledglingEvent
|
||||
|
||||
if createRequest.RoomName != "" {
|
||||
nameEvent = &gomatrixserverlib.FledglingEvent{
|
||||
Type: spec.MRoomName,
|
||||
Content: eventutil.NameContent{
|
||||
Name: createRequest.RoomName,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if createRequest.Topic != "" {
|
||||
topicEvent = &gomatrixserverlib.FledglingEvent{
|
||||
Type: spec.MRoomTopic,
|
||||
Content: eventutil.TopicContent{
|
||||
Topic: createRequest.Topic,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if guestsCanJoin {
|
||||
guestAccessEvent = &gomatrixserverlib.FledglingEvent{
|
||||
Type: spec.MRoomGuestAccess,
|
||||
Content: eventutil.GuestAccessContent{
|
||||
GuestAccess: "can_join",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var roomAlias string
|
||||
if createRequest.RoomAliasName != "" {
|
||||
roomAlias = fmt.Sprintf("#%s:%s", createRequest.RoomAliasName, userID.Domain())
|
||||
// check it's free
|
||||
// TODO: This races but is better than nothing
|
||||
hasAliasReq := api.GetRoomIDForAliasRequest{
|
||||
Alias: roomAlias,
|
||||
IncludeAppservices: false,
|
||||
}
|
||||
|
||||
var aliasResp api.GetRoomIDForAliasResponse
|
||||
err = c.RSAPI.GetRoomIDForAlias(ctx, &hasAliasReq, &aliasResp)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed")
|
||||
return nil, spec.InternalServerError{Err: err.Error()}
|
||||
}
|
||||
if aliasResp.RoomID != "" {
|
||||
return nil, spec.RoomInUse("Room ID already exists.")
|
||||
}
|
||||
|
||||
aliasEvent = &gomatrixserverlib.FledglingEvent{
|
||||
Type: spec.MRoomCanonicalAlias,
|
||||
Content: eventutil.CanonicalAlias{
|
||||
Alias: roomAlias,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var initialStateEvents []gomatrixserverlib.FledglingEvent
|
||||
for i := range createRequest.InitialState {
|
||||
if createRequest.InitialState[i].StateKey != "" {
|
||||
initialStateEvents = append(initialStateEvents, createRequest.InitialState[i])
|
||||
continue
|
||||
}
|
||||
|
||||
switch createRequest.InitialState[i].Type {
|
||||
case spec.MRoomCreate:
|
||||
continue
|
||||
|
||||
case spec.MRoomPowerLevels:
|
||||
powerLevelEvent = createRequest.InitialState[i]
|
||||
|
||||
case spec.MRoomJoinRules:
|
||||
joinRuleEvent = createRequest.InitialState[i]
|
||||
|
||||
case spec.MRoomHistoryVisibility:
|
||||
historyVisibilityEvent = createRequest.InitialState[i]
|
||||
|
||||
case spec.MRoomGuestAccess:
|
||||
guestAccessEvent = &createRequest.InitialState[i]
|
||||
|
||||
case spec.MRoomName:
|
||||
nameEvent = &createRequest.InitialState[i]
|
||||
|
||||
case spec.MRoomTopic:
|
||||
topicEvent = &createRequest.InitialState[i]
|
||||
|
||||
default:
|
||||
initialStateEvents = append(initialStateEvents, createRequest.InitialState[i])
|
||||
}
|
||||
}
|
||||
|
||||
// send events into the room in order of:
|
||||
// 1- m.room.create
|
||||
// 2- room creator join member
|
||||
// 3- m.room.power_levels
|
||||
// 4- m.room.join_rules
|
||||
// 5- m.room.history_visibility
|
||||
// 6- m.room.canonical_alias (opt)
|
||||
// 7- m.room.guest_access (opt)
|
||||
// 8- other initial state items
|
||||
// 9- m.room.name (opt)
|
||||
// 10- m.room.topic (opt)
|
||||
// 11- invite events (opt) - with is_direct flag if applicable TODO
|
||||
// 12- 3pid invite events (opt) TODO
|
||||
// This differs from Synapse slightly. Synapse would vary the ordering of 3-7
|
||||
// depending on if those events were in "initial_state" or not. This made it
|
||||
// harder to reason about, hence sticking to a strict static ordering.
|
||||
eventsToMake := []gomatrixserverlib.FledglingEvent{
|
||||
createEvent, membershipEvent, powerLevelEvent, joinRuleEvent, historyVisibilityEvent,
|
||||
}
|
||||
if guestAccessEvent != nil {
|
||||
eventsToMake = append(eventsToMake, *guestAccessEvent)
|
||||
}
|
||||
eventsToMake = append(eventsToMake, initialStateEvents...)
|
||||
if nameEvent != nil {
|
||||
eventsToMake = append(eventsToMake, *nameEvent)
|
||||
}
|
||||
if topicEvent != nil {
|
||||
eventsToMake = append(eventsToMake, *topicEvent)
|
||||
}
|
||||
if aliasEvent != nil {
|
||||
// TODO: bit of a chicken and egg problem here as the alias doesn't exist and cannot until we have made the room.
|
||||
// This means we might fail creating the alias but say the canonical alias is something that doesn't exist.
|
||||
eventsToMake = append(eventsToMake, *aliasEvent)
|
||||
}
|
||||
|
||||
var builtEvents []gomatrixserverlib.PDU
|
||||
authEvents := gomatrixserverlib.NewAuthEvents(nil)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("rsapi.QuerySenderIDForUser failed")
|
||||
return nil, spec.InternalServerError{Err: err.Error()}
|
||||
}
|
||||
for i, e := range eventsToMake {
|
||||
depth := i + 1 // depth starts at 1
|
||||
|
||||
builder := verImpl.NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{
|
||||
SenderID: string(senderID),
|
||||
RoomID: roomID.String(),
|
||||
Type: e.Type,
|
||||
StateKey: &e.StateKey,
|
||||
Depth: int64(depth),
|
||||
})
|
||||
err = builder.SetContent(e.Content)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("builder.SetContent failed")
|
||||
return nil, spec.InternalServerError{Err: err.Error()}
|
||||
}
|
||||
if i > 0 {
|
||||
builder.PrevEvents = []string{builtEvents[i-1].EventID()}
|
||||
}
|
||||
var ev gomatrixserverlib.PDU
|
||||
if err = builder.AddAuthEvents(&authEvents); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("AddAuthEvents failed")
|
||||
return nil, spec.InternalServerError{Err: err.Error()}
|
||||
}
|
||||
|
||||
ev, err = builder.BuildWithoutSigning(createRequest.EventTime, userID.Domain())
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("buildEvent failed")
|
||||
return nil, spec.InternalServerError{Err: err.Error()}
|
||||
}
|
||||
|
||||
if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return c.RSAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed")
|
||||
return nil, spec.InternalServerError{Err: err.Error()}
|
||||
}
|
||||
|
||||
// Add the event to the list of auth events
|
||||
builtEvents = append(builtEvents, ev)
|
||||
err = authEvents.AddEvent(ev)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("authEvents.AddEvent failed")
|
||||
return nil, spec.InternalServerError{Err: err.Error()}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(#269): Reserve room alias while we create the room. This stops us
|
||||
// from creating the room but still failing due to the alias having already
|
||||
// been taken.
|
||||
if roomAlias != "" {
|
||||
aliasAlreadyExists, aliasErr := c.RSAPI.SetRoomAlias(ctx, senderID, roomID, roomAlias)
|
||||
if aliasErr != nil {
|
||||
util.GetLogger(ctx).WithError(aliasErr).Error("aliasAPI.SetRoomAlias failed")
|
||||
return nil, spec.InternalServerError{Err: aliasErr.Error()}
|
||||
}
|
||||
|
||||
if aliasAlreadyExists {
|
||||
return nil, spec.RoomInUse("Room alias already exists.")
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: cryptoIDs - this shouldn't really be done until the client calls /sendPDUs with these events
|
||||
// But that would require the visibility setting also being passed along
|
||||
if createRequest.Visibility == spec.Public {
|
||||
// expose this room in the published room list
|
||||
if err = c.RSAPI.PerformPublish(ctx, &api.PerformPublishRequest{
|
||||
RoomID: roomID.String(),
|
||||
Visibility: spec.Public,
|
||||
}); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("failed to publish room")
|
||||
return nil, spec.InternalServerError{Err: err.Error()}
|
||||
}
|
||||
}
|
||||
|
||||
// If this is a direct message then we should invite the participants.
|
||||
//if len(createRequest.InvitedUsers) > 0 {
|
||||
// Build some stripped state for the invite.
|
||||
//var globalStrippedState []gomatrixserverlib.InviteStrippedState
|
||||
//for _, event := range builtEvents {
|
||||
// // Chosen events from the spec:
|
||||
// // https://spec.matrix.org/v1.3/client-server-api/#stripped-state
|
||||
// switch event.Type() {
|
||||
// case spec.MRoomCreate:
|
||||
// fallthrough
|
||||
// case spec.MRoomName:
|
||||
// fallthrough
|
||||
// case spec.MRoomAvatar:
|
||||
// fallthrough
|
||||
// case spec.MRoomTopic:
|
||||
// fallthrough
|
||||
// case spec.MRoomCanonicalAlias:
|
||||
// fallthrough
|
||||
// case spec.MRoomEncryption:
|
||||
// fallthrough
|
||||
// case spec.MRoomMember:
|
||||
// fallthrough
|
||||
// case spec.MRoomJoinRules:
|
||||
// ev := event
|
||||
// globalStrippedState = append(
|
||||
// globalStrippedState,
|
||||
// gomatrixserverlib.NewInviteStrippedState(ev),
|
||||
// )
|
||||
// }
|
||||
//}
|
||||
|
||||
// Process the invites.
|
||||
//for _, invitee := range createRequest.InvitedUsers {
|
||||
//inviteeUserID, userIDErr := spec.NewUserID(invitee, true)
|
||||
//if userIDErr != nil {
|
||||
// util.GetLogger(ctx).WithError(userIDErr).Error("invalid UserID")
|
||||
// return nil, spec.InternalServerError{}
|
||||
//}
|
||||
|
||||
// TODO: cryptoIDs - these shouldn't be here
|
||||
// instead we should return proto invite events?
|
||||
//err = c.RSAPI.PerformInvite(ctx, &api.PerformInviteRequest{
|
||||
// InviteInput: api.InviteInput{
|
||||
// RoomID: roomID,
|
||||
// Inviter: userID,
|
||||
// Invitee: *inviteeUserID,
|
||||
// DisplayName: createRequest.UserDisplayName,
|
||||
// AvatarURL: createRequest.UserAvatarURL,
|
||||
// Reason: "",
|
||||
// IsDirect: createRequest.IsDirect,
|
||||
// KeyID: createRequest.KeyID,
|
||||
// PrivateKey: createRequest.PrivateKey,
|
||||
// EventTime: createRequest.EventTime,
|
||||
// },
|
||||
// InviteRoomState: globalStrippedState,
|
||||
// SendAsServer: string(userID.Domain()),
|
||||
//})
|
||||
//switch e := err.(type) {
|
||||
//case api.ErrInvalidID:
|
||||
// return nil, spec.Unknown(e.Error())
|
||||
//case api.ErrNotAllowed:
|
||||
// return nil, spec.Forbidden(e.Error())
|
||||
//case nil:
|
||||
//default:
|
||||
// util.GetLogger(ctx).WithError(err).Error("PerformInvite failed")
|
||||
// sentry.CaptureException(err)
|
||||
// return nil, spec.InternalServerError{}
|
||||
//}
|
||||
//}
|
||||
//}
|
||||
|
||||
return builtEvents, nil
|
||||
}
|
||||
|
||||
// PerformCreateRoom handles all the steps necessary to create a new room.
|
||||
// nolint: gocyclo
|
||||
func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *api.PerformCreateRoomRequest) (string, *util.JSONResponse) {
|
||||
|
@ -501,7 +912,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
|
|||
}
|
||||
}
|
||||
|
||||
err = c.RSAPI.PerformInvite(ctx, &api.PerformInviteRequest{
|
||||
_, err = c.RSAPI.PerformInvite(ctx, &api.PerformInviteRequest{
|
||||
InviteInput: api.InviteInput{
|
||||
RoomID: roomID,
|
||||
Inviter: userID,
|
||||
|
@ -516,7 +927,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
|
|||
},
|
||||
InviteRoomState: globalStrippedState,
|
||||
SendAsServer: string(userID.Domain()),
|
||||
})
|
||||
}, false)
|
||||
switch e := err.(type) {
|
||||
case api.ErrInvalidID:
|
||||
return "", &util.JSONResponse{
|
||||
|
|
|
@ -125,16 +125,17 @@ func (r *Inviter) ProcessInviteMembership(
|
|||
func (r *Inviter) PerformInvite(
|
||||
ctx context.Context,
|
||||
req *api.PerformInviteRequest,
|
||||
) error {
|
||||
cryptoIDs bool,
|
||||
) (gomatrixserverlib.PDU, error) {
|
||||
senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.InviteInput.RoomID, req.InviteInput.Inviter)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
} else if senderID == nil {
|
||||
return fmt.Errorf("sender ID not found for %s in %s", req.InviteInput.Inviter, req.InviteInput.RoomID)
|
||||
return nil, fmt.Errorf("sender ID not found for %s in %s", req.InviteInput.Inviter, req.InviteInput.RoomID)
|
||||
}
|
||||
info, err := r.DB.RoomInfo(ctx, req.InviteInput.RoomID.String())
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
proto := gomatrixserverlib.ProtoEvent{
|
||||
|
@ -152,20 +153,20 @@ func (r *Inviter) PerformInvite(
|
|||
}
|
||||
|
||||
if err = proto.SetContent(content); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Inviter.Domain()) {
|
||||
return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")}
|
||||
return nil, api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")}
|
||||
}
|
||||
|
||||
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Invitee.Domain())
|
||||
|
||||
signingKey := req.InviteInput.PrivateKey
|
||||
if info.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
|
||||
if !cryptoIDs && info.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
|
||||
signingKey, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, req.InviteInput.Inviter, req.InviteInput.RoomID)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -222,6 +223,10 @@ func (r *Inviter) PerformInvite(
|
|||
}
|
||||
return r.RSAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID)
|
||||
},
|
||||
CryptoIDs: cryptoIDs,
|
||||
ClaimSenderID: func(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
|
||||
return r.RSAPI.ClaimOneTimeSenderIDForUser(ctx, roomID, userID)
|
||||
},
|
||||
}
|
||||
|
||||
inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI)
|
||||
|
@ -229,12 +234,14 @@ func (r *Inviter) PerformInvite(
|
|||
switch e := err.(type) {
|
||||
case spec.MatrixError:
|
||||
if e.ErrCode == spec.ErrorForbidden {
|
||||
return api.ErrNotAllowed{Err: fmt.Errorf("%s", e.Err)}
|
||||
return nil, api.ErrNotAllowed{Err: fmt.Errorf("%s", e.Err)}
|
||||
}
|
||||
}
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var response gomatrixserverlib.PDU
|
||||
if !cryptoIDs {
|
||||
// Send the invite event to the roomserver input stream. This will
|
||||
// notify existing users in the room about the invite, update the
|
||||
// membership table and ensure that the event is ready and available
|
||||
|
@ -254,8 +261,11 @@ func (r *Inviter) PerformInvite(
|
|||
r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes)
|
||||
if err := inputRes.Err(); err != nil {
|
||||
util.GetLogger(ctx).WithField("event_id", inviteEvent.EventID()).Error("r.InputRoomEvents failed")
|
||||
return api.ErrNotAllowed{Err: err}
|
||||
return nil, api.ErrNotAllowed{Err: err}
|
||||
}
|
||||
} else {
|
||||
response = inviteEvent
|
||||
}
|
||||
|
||||
return nil
|
||||
return response, nil
|
||||
}
|
||||
|
|
|
@ -406,6 +406,380 @@ func (r *Joiner) performJoinRoomByID(
|
|||
return req.RoomIDOrAlias, userDomain, nil
|
||||
}
|
||||
|
||||
func (r *Joiner) PerformSendJoinCryptoIDs(
|
||||
ctx context.Context,
|
||||
req *rsAPI.PerformJoinRequestCryptoIDs,
|
||||
) error {
|
||||
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
|
||||
"room_id": req.RoomID,
|
||||
"user_id": req.UserID,
|
||||
"servers": req.ServerNames,
|
||||
})
|
||||
logger.Info("performing send join")
|
||||
res := fsAPI.PerformJoinResponse{}
|
||||
r.FSAPI.PerformSendJoin(ctx, &fsAPI.PerformSendJoinRequestCryptoIDs{
|
||||
RoomID: req.RoomID,
|
||||
UserID: req.UserID,
|
||||
ServerNames: req.ServerNames,
|
||||
Unsigned: req.Unsigned,
|
||||
Event: req.JoinEvent,
|
||||
}, &res)
|
||||
if res.LastError != nil {
|
||||
return res.LastError
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Joiner) PerformSendInviteCryptoIDs(
|
||||
ctx context.Context,
|
||||
req *rsAPI.PerformInviteRequestCryptoIDs,
|
||||
) error {
|
||||
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
|
||||
"room_id": req.RoomID,
|
||||
"user_id": req.UserID,
|
||||
})
|
||||
logger.Info("performing send invite")
|
||||
err := r.FSAPI.SendInviteCryptoIDs(ctx, req.InviteEvent, req.UserID, req.InviteEvent.Version())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PerformJoin handles joining matrix rooms, including over federation by talking to the federationapi.
|
||||
func (r *Joiner) PerformJoinCryptoIDs(
|
||||
ctx context.Context,
|
||||
req *rsAPI.PerformJoinRequest,
|
||||
) (joinEvent gomatrixserverlib.PDU, roomID string, version gomatrixserverlib.RoomVersion, serverName spec.ServerName, err error) {
|
||||
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
|
||||
"room_id": req.RoomIDOrAlias,
|
||||
"user_id": req.UserID,
|
||||
"servers": req.ServerNames,
|
||||
})
|
||||
logger.Info("User requested to room join")
|
||||
join, roomID, version, serverName, err := r.makeJoinEvent(context.Background(), req)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("Failed to make join room event")
|
||||
sentry.CaptureException(err)
|
||||
return nil, "", "", "", err
|
||||
}
|
||||
|
||||
return join, roomID, version, serverName, nil
|
||||
}
|
||||
|
||||
func (r *Joiner) makeJoinEvent(
|
||||
ctx context.Context,
|
||||
req *rsAPI.PerformJoinRequest,
|
||||
) (gomatrixserverlib.PDU, string, gomatrixserverlib.RoomVersion, spec.ServerName, error) {
|
||||
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||
if err != nil {
|
||||
return nil, "", "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("supplied user ID %q in incorrect format", req.UserID)}
|
||||
}
|
||||
if !r.Cfg.Matrix.IsLocalServerName(domain) {
|
||||
return nil, "", "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user %q does not belong to this homeserver", req.UserID)}
|
||||
}
|
||||
if strings.HasPrefix(req.RoomIDOrAlias, "!") {
|
||||
return r.performJoinRoomByIDCryptoIDs(ctx, req)
|
||||
}
|
||||
if strings.HasPrefix(req.RoomIDOrAlias, "#") {
|
||||
return r.performJoinRoomByAliasCryptoIDs(ctx, req)
|
||||
}
|
||||
return nil, "", "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID or alias %q is invalid", req.RoomIDOrAlias)}
|
||||
}
|
||||
|
||||
func (r *Joiner) performJoinRoomByAliasCryptoIDs(
|
||||
ctx context.Context,
|
||||
req *rsAPI.PerformJoinRequest,
|
||||
) (gomatrixserverlib.PDU, string, gomatrixserverlib.RoomVersion, spec.ServerName, error) {
|
||||
// Get the domain part of the room alias.
|
||||
_, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias)
|
||||
if err != nil {
|
||||
return nil, "", "", "", fmt.Errorf("alias %q is not in the correct format", req.RoomIDOrAlias)
|
||||
}
|
||||
req.ServerNames = append(req.ServerNames, domain)
|
||||
|
||||
// Check if this alias matches our own server configuration. If it
|
||||
// doesn't then we'll need to try a federated join.
|
||||
var roomID string
|
||||
if !r.Cfg.Matrix.IsLocalServerName(domain) {
|
||||
// The alias isn't owned by us, so we will need to try joining using
|
||||
// a remote server.
|
||||
dirReq := fsAPI.PerformDirectoryLookupRequest{
|
||||
RoomAlias: req.RoomIDOrAlias, // the room alias to lookup
|
||||
ServerName: domain, // the server to ask
|
||||
}
|
||||
dirRes := fsAPI.PerformDirectoryLookupResponse{}
|
||||
err = r.FSAPI.PerformDirectoryLookup(ctx, &dirReq, &dirRes)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("error looking up alias %q", req.RoomIDOrAlias)
|
||||
return nil, "", "", "", fmt.Errorf("looking up alias %q over federation failed: %w", req.RoomIDOrAlias, err)
|
||||
}
|
||||
roomID = dirRes.RoomID
|
||||
req.ServerNames = append(req.ServerNames, dirRes.ServerNames...)
|
||||
} else {
|
||||
var getRoomReq = rsAPI.GetRoomIDForAliasRequest{
|
||||
Alias: req.RoomIDOrAlias,
|
||||
IncludeAppservices: true,
|
||||
}
|
||||
var getRoomRes = rsAPI.GetRoomIDForAliasResponse{}
|
||||
// Otherwise, look up if we know this room alias locally.
|
||||
err = r.RSAPI.GetRoomIDForAlias(ctx, &getRoomReq, &getRoomRes)
|
||||
if err != nil {
|
||||
return nil, "", "", "", fmt.Errorf("lookup room alias %q failed: %w", req.RoomIDOrAlias, err)
|
||||
}
|
||||
roomID = getRoomRes.RoomID
|
||||
}
|
||||
|
||||
// If the room ID is empty then we failed to look up the alias.
|
||||
if roomID == "" {
|
||||
return nil, "", "", "", fmt.Errorf("alias %q not found", req.RoomIDOrAlias)
|
||||
}
|
||||
|
||||
// If we do, then pluck out the room ID and continue the join.
|
||||
req.RoomIDOrAlias = roomID
|
||||
return r.performJoinRoomByIDCryptoIDs(ctx, req)
|
||||
}
|
||||
|
||||
// TODO: Break this function up a bit & move to GMSL
|
||||
// nolint:gocyclo
|
||||
func (r *Joiner) performJoinRoomByIDCryptoIDs(
|
||||
ctx context.Context,
|
||||
req *rsAPI.PerformJoinRequest,
|
||||
) (gomatrixserverlib.PDU, string, gomatrixserverlib.RoomVersion, spec.ServerName, error) {
|
||||
// The original client request ?server_name=... may include this HS so filter that out so we
|
||||
// don't attempt to make_join with ourselves
|
||||
for i := 0; i < len(req.ServerNames); i++ {
|
||||
if r.Cfg.Matrix.IsLocalServerName(req.ServerNames[i]) {
|
||||
// delete this entry
|
||||
req.ServerNames = append(req.ServerNames[:i], req.ServerNames[i+1:]...)
|
||||
i--
|
||||
}
|
||||
}
|
||||
|
||||
// Get the domain part of the room ID.
|
||||
roomID, err := spec.NewRoomID(req.RoomIDOrAlias)
|
||||
if err != nil {
|
||||
return nil, "", "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", req.RoomIDOrAlias, err)}
|
||||
}
|
||||
|
||||
// If the server name in the room ID isn't ours then it's a
|
||||
// possible candidate for finding the room via federation. Add
|
||||
// it to the list of servers to try.
|
||||
if !r.Cfg.Matrix.IsLocalServerName(roomID.Domain()) {
|
||||
req.ServerNames = append(req.ServerNames, roomID.Domain())
|
||||
}
|
||||
|
||||
// Force a federated join if we aren't in the room and we've been
|
||||
// given some server names to try joining by.
|
||||
inRoomReq := &rsAPI.QueryServerJoinedToRoomRequest{
|
||||
RoomID: req.RoomIDOrAlias,
|
||||
}
|
||||
inRoomRes := &rsAPI.QueryServerJoinedToRoomResponse{}
|
||||
if err = r.Queryer.QueryServerJoinedToRoom(ctx, inRoomReq, inRoomRes); err != nil {
|
||||
return nil, "", "", "", fmt.Errorf("r.Queryer.QueryServerJoinedToRoom: %w", err)
|
||||
}
|
||||
serverInRoom := inRoomRes.IsInRoom
|
||||
forceFederatedJoin := len(req.ServerNames) > 0 && !serverInRoom
|
||||
|
||||
userID, err := spec.NewUserID(req.UserID, true)
|
||||
if err != nil {
|
||||
return nil, "", "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)}
|
||||
}
|
||||
|
||||
//TODO: CryptoIDs - what is provided & calculated senderIDs don't match?
|
||||
|
||||
// Look up the room NID for the supplied room ID.
|
||||
var senderID spec.SenderID
|
||||
checkInvitePending := false
|
||||
info, err := r.DB.RoomInfo(ctx, req.RoomIDOrAlias)
|
||||
if err == nil && info != nil {
|
||||
switch info.RoomVersion {
|
||||
case gomatrixserverlib.RoomVersionCryptoIDs:
|
||||
senderIDPtr, queryErr := r.Queryer.QuerySenderIDForUser(ctx, *roomID, *userID)
|
||||
if queryErr == nil {
|
||||
checkInvitePending = true
|
||||
}
|
||||
if senderIDPtr == nil {
|
||||
senderID = req.SenderID
|
||||
} else {
|
||||
senderID = *senderIDPtr
|
||||
}
|
||||
default:
|
||||
checkInvitePending = true
|
||||
senderID = spec.SenderID(userID.String())
|
||||
}
|
||||
}
|
||||
|
||||
// Force a federated join if we're dealing with a pending invite
|
||||
// and we aren't in the room.
|
||||
if checkInvitePending {
|
||||
isInvitePending, inviteSender, _, inviteEvent, inviteErr := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID)
|
||||
if inviteErr == nil && !serverInRoom && isInvitePending {
|
||||
inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, inviteSender)
|
||||
if queryErr != nil {
|
||||
return nil, "", "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr)
|
||||
}
|
||||
|
||||
// If we were invited by someone from another server then we can
|
||||
// assume they are in the room so we can join via them.
|
||||
if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) {
|
||||
req.ServerNames = append(req.ServerNames, inviter.Domain())
|
||||
forceFederatedJoin = true
|
||||
memberEvent := gjson.Parse(string(inviteEvent.JSON()))
|
||||
// only set unsigned if we've got a content.membership, which we _should_
|
||||
if memberEvent.Get("content.membership").Exists() {
|
||||
req.Unsigned = map[string]interface{}{
|
||||
"prev_sender": memberEvent.Get("sender").Str,
|
||||
"prev_content": map[string]interface{}{
|
||||
"is_direct": memberEvent.Get("content.is_direct").Bool(),
|
||||
"membership": memberEvent.Get("content.membership").Str,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If a guest is trying to join a room, check that the room has a m.room.guest_access event
|
||||
if req.IsGuest {
|
||||
var guestAccessEvent *types.HeaderedEvent
|
||||
guestAccess := "forbidden"
|
||||
guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, spec.MRoomGuestAccess, "")
|
||||
if (err != nil && !errors.Is(err, sql.ErrNoRows)) || guestAccessEvent == nil {
|
||||
logrus.WithError(err).Warn("unable to get m.room.guest_access event, defaulting to 'forbidden'")
|
||||
}
|
||||
if guestAccessEvent != nil {
|
||||
guestAccess = gjson.GetBytes(guestAccessEvent.Content(), "guest_access").String()
|
||||
}
|
||||
|
||||
// Servers MUST only allow guest users to join rooms if the m.room.guest_access state event
|
||||
// is present on the room and has the guest_access value can_join.
|
||||
if guestAccess != "can_join" {
|
||||
return nil, "", "", "", rsAPI.ErrNotAllowed{Err: fmt.Errorf("guest access is forbidden")}
|
||||
}
|
||||
}
|
||||
|
||||
// If we should do a forced federated join then do that.
|
||||
if forceFederatedJoin {
|
||||
joinEvent, version, serverName, federatedJoinErr := r.performFederatedMakeJoinByIDCryptoIDs(ctx, req)
|
||||
return joinEvent, req.RoomIDOrAlias, version, serverName, federatedJoinErr
|
||||
}
|
||||
|
||||
// Try to construct an actual join event from the template.
|
||||
// If this succeeds then it is a sign that the room already exists
|
||||
// locally on the homeserver.
|
||||
// TODO: Check what happens if the room exists on the server
|
||||
// but everyone has since left. I suspect it does the wrong thing.
|
||||
|
||||
var buildRes rsAPI.QueryLatestEventsAndStateResponse
|
||||
identity := r.Cfg.Matrix.SigningIdentity
|
||||
|
||||
// at this point we know we have an existing room
|
||||
if inRoomRes.RoomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
|
||||
mapping := &gomatrixserverlib.MXIDMapping{
|
||||
UserRoomKey: senderID,
|
||||
UserID: userID.String(),
|
||||
}
|
||||
|
||||
// Sign the mapping with the server identity
|
||||
if err = mapping.Sign(identity.ServerName, identity.KeyID, identity.PrivateKey); err != nil {
|
||||
return nil, "", "", "", err
|
||||
}
|
||||
req.Content["mxid_mapping"] = mapping
|
||||
|
||||
// sign the event with the pseudo ID key
|
||||
identity = fclient.SigningIdentity{
|
||||
ServerName: spec.ServerName(senderID),
|
||||
KeyID: "ed25519:1",
|
||||
PrivateKey: nil,
|
||||
}
|
||||
}
|
||||
|
||||
senderIDString := string(senderID)
|
||||
|
||||
// Prepare the template for the join event.
|
||||
proto := gomatrixserverlib.ProtoEvent{
|
||||
Type: spec.MRoomMember,
|
||||
SenderID: senderIDString,
|
||||
StateKey: &senderIDString,
|
||||
RoomID: req.RoomIDOrAlias,
|
||||
Redacts: "",
|
||||
}
|
||||
if err = proto.SetUnsigned(struct{}{}); err != nil {
|
||||
return nil, "", "", "", fmt.Errorf("eb.SetUnsigned: %w", err)
|
||||
}
|
||||
|
||||
// It is possible for the request to include some "content" for the
|
||||
// event. We'll always overwrite the "membership" key, but the rest,
|
||||
// like "display_name" or "avatar_url", will be kept if supplied.
|
||||
if req.Content == nil {
|
||||
req.Content = map[string]interface{}{}
|
||||
}
|
||||
req.Content["membership"] = spec.Join
|
||||
if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil {
|
||||
return nil, "", "", "", aerr
|
||||
} else if authorisedVia != "" {
|
||||
req.Content["join_authorised_via_users_server"] = authorisedVia
|
||||
}
|
||||
if err = proto.SetContent(req.Content); err != nil {
|
||||
return nil, "", "", "", fmt.Errorf("eb.SetContent: %w", err)
|
||||
}
|
||||
joinEvent, err := eventutil.QueryAndBuildEvent(ctx, &proto, &identity, time.Now(), r.RSAPI, &buildRes)
|
||||
|
||||
switch err.(type) {
|
||||
case nil:
|
||||
// Do nothing
|
||||
|
||||
case eventutil.ErrRoomNoExists:
|
||||
// The room doesn't exist locally. If the room ID looks like it should
|
||||
// be ours then this probably means that we've nuked our database at
|
||||
// some point.
|
||||
if r.Cfg.Matrix.IsLocalServerName(roomID.Domain()) {
|
||||
// If there are no more server names to try then give up here.
|
||||
// Otherwise we'll try a federated join as normal, since it's quite
|
||||
// possible that the room still exists on other servers.
|
||||
if len(req.ServerNames) == 0 {
|
||||
return nil, "", "", "", eventutil.ErrRoomNoExists{}
|
||||
}
|
||||
}
|
||||
|
||||
// Perform a federated room join.
|
||||
joinEvent, version, serverName, federatedJoinErr := r.performFederatedMakeJoinByIDCryptoIDs(ctx, req)
|
||||
return joinEvent, req.RoomIDOrAlias, version, serverName, federatedJoinErr
|
||||
|
||||
default:
|
||||
// Something else went wrong.
|
||||
return nil, "", "", "", fmt.Errorf("error joining local room: %q", err)
|
||||
}
|
||||
|
||||
// By this point, if req.RoomIDOrAlias contained an alias, then
|
||||
// it will have been overwritten with a room ID by performJoinRoomByAlias.
|
||||
// We should now include this in the response so that the CS API can
|
||||
// return the right room ID.
|
||||
return joinEvent, req.RoomIDOrAlias, inRoomRes.RoomVersion, userID.Domain(), nil
|
||||
}
|
||||
|
||||
func (r *Joiner) performFederatedMakeJoinByIDCryptoIDs(
|
||||
ctx context.Context,
|
||||
req *rsAPI.PerformJoinRequest,
|
||||
) (gomatrixserverlib.PDU, gomatrixserverlib.RoomVersion, spec.ServerName, error) {
|
||||
// Try joining by all of the supplied server names.
|
||||
fedReq := fsAPI.PerformJoinRequest{
|
||||
RoomID: req.RoomIDOrAlias, // the room ID to try and join
|
||||
UserID: req.UserID, // the user ID joining the room
|
||||
ServerNames: req.ServerNames, // the server to try joining with
|
||||
Content: req.Content, // the membership event content
|
||||
Unsigned: req.Unsigned, // the unsigned event content, if any
|
||||
}
|
||||
joinEvent, version, serverName, err := r.FSAPI.PerformMakeJoin(ctx, &fedReq)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
return joinEvent, version, serverName, nil
|
||||
}
|
||||
|
||||
func (r *Joiner) performFederatedJoinRoomByID(
|
||||
ctx context.Context,
|
||||
req *rsAPI.PerformJoinRequest,
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
"github.com/matrix-org/gomatrix"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/fclient"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
@ -52,9 +53,10 @@ func (r *Leaver) PerformLeave(
|
|||
ctx context.Context,
|
||||
req *api.PerformLeaveRequest,
|
||||
res *api.PerformLeaveResponse,
|
||||
) ([]api.OutputEvent, error) {
|
||||
cryptoIDs bool,
|
||||
) ([]api.OutputEvent, gomatrixserverlib.PDU, error) {
|
||||
if !r.Cfg.Matrix.IsLocalServerName(req.Leaver.Domain()) {
|
||||
return nil, fmt.Errorf("user %q does not belong to this homeserver", req.Leaver.String())
|
||||
return nil, nil, fmt.Errorf("user %q does not belong to this homeserver", req.Leaver.String())
|
||||
}
|
||||
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
|
||||
"room_id": req.RoomID,
|
||||
|
@ -62,15 +64,15 @@ func (r *Leaver) PerformLeave(
|
|||
})
|
||||
logger.Info("User requested to leave join")
|
||||
if strings.HasPrefix(req.RoomID, "!") {
|
||||
output, err := r.performLeaveRoomByID(context.Background(), req, res)
|
||||
output, event, err := r.performLeaveRoomByID(context.Background(), req, res, cryptoIDs)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("Failed to leave room")
|
||||
} else {
|
||||
logger.Info("User left room successfully")
|
||||
}
|
||||
return output, err
|
||||
return output, event, err
|
||||
}
|
||||
return nil, fmt.Errorf("room ID %q is invalid", req.RoomID)
|
||||
return nil, nil, fmt.Errorf("room ID %q is invalid", req.RoomID)
|
||||
}
|
||||
|
||||
// nolint:gocyclo
|
||||
|
@ -78,14 +80,15 @@ func (r *Leaver) performLeaveRoomByID(
|
|||
ctx context.Context,
|
||||
req *api.PerformLeaveRequest,
|
||||
res *api.PerformLeaveResponse, // nolint:unparam
|
||||
) ([]api.OutputEvent, error) {
|
||||
cryptoIDs bool,
|
||||
) ([]api.OutputEvent, gomatrixserverlib.PDU, error) {
|
||||
roomID, err := spec.NewRoomID(req.RoomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, *roomID, req.Leaver)
|
||||
if err != nil || leaver == nil {
|
||||
return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String())
|
||||
return nil, nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String())
|
||||
}
|
||||
|
||||
// If there's an invite outstanding for the room then respond to
|
||||
|
@ -94,7 +97,7 @@ func (r *Leaver) performLeaveRoomByID(
|
|||
if err == nil && isInvitePending {
|
||||
sender, serr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, senderUser)
|
||||
if serr != nil {
|
||||
return nil, fmt.Errorf("failed looking up userID for sender %q: %w", senderUser, serr)
|
||||
return nil, nil, fmt.Errorf("failed looking up userID for sender %q: %w", senderUser, serr)
|
||||
}
|
||||
|
||||
var domain spec.ServerName
|
||||
|
@ -107,7 +110,7 @@ func (r *Leaver) performLeaveRoomByID(
|
|||
domain = sender.Domain()
|
||||
}
|
||||
if !r.Cfg.Matrix.IsLocalServerName(domain) {
|
||||
return r.performFederatedRejectInvite(ctx, req, res, domain, eventID, *leaver)
|
||||
return r.performFederatedRejectInvite(ctx, req, res, domain, eventID, *leaver, cryptoIDs)
|
||||
}
|
||||
// check that this is not a "server notice room"
|
||||
accData := &userapi.QueryAccountDataResponse{}
|
||||
|
@ -116,7 +119,7 @@ func (r *Leaver) performLeaveRoomByID(
|
|||
RoomID: req.RoomID,
|
||||
DataType: "m.tag",
|
||||
}, accData); err != nil {
|
||||
return nil, fmt.Errorf("unable to query account data: %w", err)
|
||||
return nil, nil, fmt.Errorf("unable to query account data: %w", err)
|
||||
}
|
||||
|
||||
if roomData, ok := accData.RoomAccountData[req.RoomID]; ok {
|
||||
|
@ -124,13 +127,13 @@ func (r *Leaver) performLeaveRoomByID(
|
|||
if ok {
|
||||
tags := gomatrix.TagContent{}
|
||||
if err = json.Unmarshal(tagData, &tags); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal tag content")
|
||||
return nil, nil, fmt.Errorf("unable to unmarshal tag content")
|
||||
}
|
||||
if _, ok = tags.Tags["m.server_notice"]; ok {
|
||||
// mimic the returned values from Synapse
|
||||
res.Message = "You cannot reject this invite"
|
||||
res.Code = 403
|
||||
return nil, spec.LeaveServerNoticeError()
|
||||
return nil, nil, spec.LeaveServerNoticeError()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -149,22 +152,22 @@ func (r *Leaver) performLeaveRoomByID(
|
|||
}
|
||||
latestRes := api.QueryLatestEventsAndStateResponse{}
|
||||
if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r.RSAPI, &latestReq, &latestRes); err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
if !latestRes.RoomExists {
|
||||
return nil, fmt.Errorf("room %q does not exist", req.RoomID)
|
||||
return nil, nil, fmt.Errorf("room %q does not exist", req.RoomID)
|
||||
}
|
||||
|
||||
// Now let's see if the user is in the room.
|
||||
if len(latestRes.StateEvents) == 0 {
|
||||
return nil, fmt.Errorf("user %q is not a member of room %q", req.Leaver.String(), req.RoomID)
|
||||
return nil, nil, fmt.Errorf("user %q is not a member of room %q", req.Leaver.String(), req.RoomID)
|
||||
}
|
||||
membership, err := latestRes.StateEvents[0].Membership()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting membership: %w", err)
|
||||
return nil, nil, fmt.Errorf("error getting membership: %w", err)
|
||||
}
|
||||
if membership != spec.Join && membership != spec.Invite {
|
||||
return nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.Leaver.String(), membership)
|
||||
return nil, nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.Leaver.String(), membership)
|
||||
}
|
||||
|
||||
// Prepare the template for the leave event.
|
||||
|
@ -177,10 +180,10 @@ func (r *Leaver) performLeaveRoomByID(
|
|||
Redacts: "",
|
||||
}
|
||||
if err = proto.SetContent(map[string]interface{}{"membership": "leave"}); err != nil {
|
||||
return nil, fmt.Errorf("eb.SetContent: %w", err)
|
||||
return nil, nil, fmt.Errorf("eb.SetContent: %w", err)
|
||||
}
|
||||
if err = proto.SetUnsigned(struct{}{}); err != nil {
|
||||
return nil, fmt.Errorf("eb.SetUnsigned: %w", err)
|
||||
return nil, nil, fmt.Errorf("eb.SetUnsigned: %w", err)
|
||||
}
|
||||
|
||||
// We know that the user is in the room at this point so let's build
|
||||
|
@ -190,19 +193,29 @@ func (r *Leaver) performLeaveRoomByID(
|
|||
|
||||
validRoomID, err := spec.NewRoomID(req.RoomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var buildRes rsAPI.QueryLatestEventsAndStateResponse
|
||||
identity, err := r.RSAPI.SigningIdentityFor(ctx, *validRoomID, req.Leaver)
|
||||
var identity fclient.SigningIdentity
|
||||
if !cryptoIDs {
|
||||
identity, err = r.RSAPI.SigningIdentityFor(ctx, *validRoomID, req.Leaver)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SigningIdentityFor: %w", err)
|
||||
return nil, nil, fmt.Errorf("SigningIdentityFor: %w", err)
|
||||
}
|
||||
} else {
|
||||
identity = fclient.SigningIdentity{
|
||||
ServerName: spec.ServerName(*leaver),
|
||||
KeyID: "ed25519:1",
|
||||
PrivateKey: nil,
|
||||
}
|
||||
}
|
||||
event, err := eventutil.QueryAndBuildEvent(ctx, &proto, &identity, time.Now(), r.RSAPI, &buildRes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("eventutil.QueryAndBuildEvent: %w", err)
|
||||
return nil, nil, fmt.Errorf("eventutil.QueryAndBuildEvent: %w", err)
|
||||
}
|
||||
|
||||
if !cryptoIDs {
|
||||
// Give our leave event to the roomserver input stream. The
|
||||
// roomserver will process the membership change and notify
|
||||
// downstream automatically.
|
||||
|
@ -219,10 +232,11 @@ func (r *Leaver) performLeaveRoomByID(
|
|||
inputRes := api.InputRoomEventsResponse{}
|
||||
r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes)
|
||||
if err = inputRes.Err(); err != nil {
|
||||
return nil, fmt.Errorf("r.InputRoomEvents: %w", err)
|
||||
return nil, nil, fmt.Errorf("r.InputRoomEvents: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
return nil, event, nil
|
||||
}
|
||||
|
||||
func (r *Leaver) performFederatedRejectInvite(
|
||||
|
@ -231,7 +245,8 @@ func (r *Leaver) performFederatedRejectInvite(
|
|||
res *api.PerformLeaveResponse, // nolint:unparam
|
||||
inviteDomain spec.ServerName, eventID string,
|
||||
leaver spec.SenderID,
|
||||
) ([]api.OutputEvent, error) {
|
||||
cryptoIDs bool,
|
||||
) ([]api.OutputEvent, gomatrixserverlib.PDU, error) {
|
||||
// Ask the federation sender to perform a federated leave for us.
|
||||
leaveReq := fsAPI.PerformLeaveRequest{
|
||||
RoomID: req.RoomID,
|
||||
|
@ -239,7 +254,7 @@ func (r *Leaver) performFederatedRejectInvite(
|
|||
ServerNames: []spec.ServerName{inviteDomain},
|
||||
}
|
||||
leaveRes := fsAPI.PerformLeaveResponse{}
|
||||
if err := r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil {
|
||||
if err := r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes, cryptoIDs); err != nil {
|
||||
// failures in PerformLeave should NEVER stop us from telling other components like the
|
||||
// sync API that the invite was withdrawn. Otherwise we can end up with stuck invites.
|
||||
util.GetLogger(ctx).WithError(err).Errorf("failed to PerformLeave, still retiring invite event")
|
||||
|
@ -279,5 +294,5 @@ func (r *Leaver) performFederatedRejectInvite(
|
|||
TargetSenderID: leaver,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}, nil, nil
|
||||
}
|
||||
|
|
|
@ -1044,7 +1044,7 @@ func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID,
|
|||
}
|
||||
|
||||
switch version {
|
||||
case gomatrixserverlib.RoomVersionPseudoIDs:
|
||||
case gomatrixserverlib.RoomVersionPseudoIDs, gomatrixserverlib.RoomVersionCryptoIDs:
|
||||
key, err := r.DB.SelectUserRoomPublicKey(ctx, userID, roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -16,7 +16,8 @@ type RoomServer struct {
|
|||
}
|
||||
|
||||
func (c *RoomServer) Defaults(opts DefaultOpts) {
|
||||
c.DefaultRoomVersion = gomatrixserverlib.RoomVersionV10
|
||||
//c.DefaultRoomVersion = gomatrixserverlib.RoomVersionV10
|
||||
c.DefaultRoomVersion = gomatrixserverlib.RoomVersionCryptoIDs
|
||||
if opts.Generate {
|
||||
if !opts.SingleDatabase {
|
||||
c.Database.ConnectionString = "file:roomserver.db"
|
||||
|
|
|
@ -46,6 +46,16 @@ func DeviceOTKCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID, deviceI
|
|||
return nil
|
||||
}
|
||||
|
||||
// OTCryptoIDCounts adds one-time pseudoID counts to the /sync response
|
||||
func OTCryptoIDCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID string, res *types.Response) error {
|
||||
count, err := keyAPI.QueryOneTimeCryptoIDs(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.OTCryptoIDsCount = count.KeyCount
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
|
||||
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
|
||||
// be already filled in with join/leave information.
|
||||
|
|
|
@ -50,7 +50,9 @@ func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyC
|
|||
}
|
||||
func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error {
|
||||
return nil
|
||||
|
||||
}
|
||||
func (a *mockKeyAPI) QueryOneTimeCryptoIDs(ctx context.Context, userID string) (userapi.OneTimeCryptoIDsCount, *userapi.KeyError) {
|
||||
return userapi.OneTimeCryptoIDsCount{}, nil
|
||||
}
|
||||
func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.QueryDeviceMessagesRequest, res *userapi.QueryDeviceMessagesResponse) error {
|
||||
return nil
|
||||
|
|
|
@ -38,7 +38,12 @@ func (p *DeviceListStreamProvider) IncrementalSync(
|
|||
}
|
||||
err = internal.DeviceOTKCounts(req.Context, p.userAPI, req.Device.UserID, req.Device.ID, req.Response)
|
||||
if err != nil {
|
||||
req.Log.WithError(err).Error("internal.DeviceListCatchup failed")
|
||||
req.Log.WithError(err).Error("internal.DeviceOTKCounts failed")
|
||||
return from
|
||||
}
|
||||
err = internal.OTCryptoIDCounts(req.Context, p.userAPI, req.Device.UserID, req.Response)
|
||||
if err != nil {
|
||||
req.Log.WithError(err).Error("internal.OTPseudoIDCounts failed")
|
||||
return from
|
||||
}
|
||||
|
||||
|
|
|
@ -280,6 +280,10 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
|
|||
if err != nil && err != context.Canceled {
|
||||
syncReq.Log.WithError(err).Warn("failed to get OTK counts")
|
||||
}
|
||||
err = internal.OTCryptoIDCounts(syncReq.Context, rp.userAPI, syncReq.Device.UserID, syncReq.Response)
|
||||
if err != nil && err != context.Canceled {
|
||||
syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts")
|
||||
}
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
|
|
|
@ -112,6 +112,10 @@ func (s *syncUserAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOn
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *syncUserAPI) QueryOneTimeCryptoIDs(ctx context.Context, userID string) (userapi.OneTimeCryptoIDsCount, *userapi.KeyError) {
|
||||
return userapi.OneTimeCryptoIDsCount{}, nil
|
||||
}
|
||||
|
||||
func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -153,7 +153,7 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, userIDFor
|
|||
// TODO: Set Signatures & Hashes fields
|
||||
}
|
||||
|
||||
if format != FormatSyncFederation && se.Version() == gomatrixserverlib.RoomVersionPseudoIDs {
|
||||
if format != FormatSyncFederation && (se.Version() == gomatrixserverlib.RoomVersionPseudoIDs || se.Version() == gomatrixserverlib.RoomVersionCryptoIDs) {
|
||||
err := updatePseudoIDs(&ce, se, userIDForSender, format)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -304,7 +304,7 @@ func GetUpdatedInviteRoomState(userIDForSender spec.UserIDForSender, inviteRoomS
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && eventFormat != FormatSyncFederation {
|
||||
if (event.Version() == gomatrixserverlib.RoomVersionPseudoIDs || event.Version() == gomatrixserverlib.RoomVersionCryptoIDs) && eventFormat != FormatSyncFederation {
|
||||
for i, ev := range inviteStateEvents {
|
||||
userID, userIDErr := userIDForSender(roomID, spec.SenderID(ev.SenderID))
|
||||
if userIDErr != nil {
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
|
@ -365,6 +366,7 @@ type Response struct {
|
|||
ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
|
||||
DeviceLists *DeviceLists `json:"device_lists,omitempty"`
|
||||
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
|
||||
OTCryptoIDsCount map[string]int `json:"one_time_cryptoids_count,omitempty"`
|
||||
}
|
||||
|
||||
func (r Response) MarshalJSON() ([]byte, error) {
|
||||
|
@ -427,6 +429,7 @@ func NewResponse() *Response {
|
|||
res.DeviceLists = &DeviceLists{}
|
||||
res.ToDevice = &ToDeviceResponse{}
|
||||
res.DeviceListsOTKCount = map[string]int{}
|
||||
res.OTCryptoIDsCount = map[string]int{}
|
||||
|
||||
return &res
|
||||
}
|
||||
|
@ -530,6 +533,7 @@ type InviteResponse struct {
|
|||
InviteState struct {
|
||||
Events []json.RawMessage `json:"events"`
|
||||
} `json:"invite_state"`
|
||||
OneTimeCryptoID string `json:"one_time_cryptoid,omitempty"`
|
||||
}
|
||||
|
||||
// NewInviteResponse creates an empty response with initialised arrays.
|
||||
|
@ -537,11 +541,17 @@ func NewInviteResponse(ctx context.Context, rsAPI api.QuerySenderIDAPI, event *t
|
|||
res := InviteResponse{}
|
||||
res.InviteState.Events = []json.RawMessage{}
|
||||
|
||||
logrus.Infof("Room version: %s", event.Version())
|
||||
if event.Version() == gomatrixserverlib.RoomVersionCryptoIDs {
|
||||
logrus.Infof("Setting invite cryptoID to %s", *event.PDU.StateKey())
|
||||
res.OneTimeCryptoID = *event.PDU.StateKey()
|
||||
}
|
||||
|
||||
// First see if there's invite_room_state in the unsigned key of the invite.
|
||||
// If there is then unmarshal it into the response. This will contain the
|
||||
// partial room state such as join rules, room name etc.
|
||||
if inviteRoomState := gjson.GetBytes(event.Unsigned(), "invite_room_state"); inviteRoomState.Exists() {
|
||||
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && eventFormat != synctypes.FormatSyncFederation {
|
||||
if (event.Version() == gomatrixserverlib.RoomVersionPseudoIDs || event.Version() == gomatrixserverlib.RoomVersionCryptoIDs) && eventFormat != synctypes.FormatSyncFederation {
|
||||
updatedInvite, err := synctypes.GetUpdatedInviteRoomState(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}, inviteRoomState, event.PDU, event.RoomID(), eventFormat)
|
||||
|
|
|
@ -51,6 +51,7 @@ type AppserviceUserAPI interface {
|
|||
type RoomserverUserAPI interface {
|
||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
|
||||
ClaimOneTimeCryptoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
|
||||
}
|
||||
|
||||
// api functions required by the media api
|
||||
|
@ -669,6 +670,7 @@ type UploadDeviceKeysAPI interface {
|
|||
type SyncKeyAPI interface {
|
||||
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error
|
||||
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error
|
||||
QueryOneTimeCryptoIDs(ctx context.Context, userID string) (OneTimeCryptoIDsCount, *KeyError)
|
||||
PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
|
||||
}
|
||||
|
||||
|
@ -772,12 +774,25 @@ type OneTimeKeys struct {
|
|||
KeyJSON map[string]json.RawMessage
|
||||
}
|
||||
|
||||
type OneTimeCryptoIDs struct {
|
||||
// The user who owns this device
|
||||
UserID string
|
||||
// A map of algorithm:key_id => key JSON
|
||||
KeyJSON map[string]json.RawMessage
|
||||
}
|
||||
|
||||
// Split a key in KeyJSON into algorithm and key ID
|
||||
func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
|
||||
segments := strings.Split(keyIDWithAlgo, ":")
|
||||
return segments[0], segments[1]
|
||||
}
|
||||
|
||||
// Split a key in KeyJSON into algorithm and key ID
|
||||
func (k *OneTimeCryptoIDs) Split(keyIDWithAlgo string) (algo string, keyID string) {
|
||||
segments := strings.Split(keyIDWithAlgo, ":")
|
||||
return segments[0], segments[1]
|
||||
}
|
||||
|
||||
// OneTimeKeysCount represents the counts of one-time keys for a single device
|
||||
type OneTimeKeysCount struct {
|
||||
// The user who owns this device
|
||||
|
@ -792,12 +807,23 @@ type OneTimeKeysCount struct {
|
|||
KeyCount map[string]int
|
||||
}
|
||||
|
||||
type OneTimeCryptoIDsCount struct {
|
||||
// The user who owns this device
|
||||
UserID string
|
||||
// algorithm to count e.g:
|
||||
// {
|
||||
// "pseudoid_curve25519": 10,
|
||||
// }
|
||||
KeyCount map[string]int
|
||||
}
|
||||
|
||||
// PerformUploadKeysRequest is the request to PerformUploadKeys
|
||||
type PerformUploadKeysRequest struct {
|
||||
UserID string // Required - User performing the request
|
||||
DeviceID string // Optional - Device performing the request, for fetching OTK count
|
||||
DeviceKeys []DeviceKeys
|
||||
OneTimeKeys []OneTimeKeys
|
||||
OneTimeCryptoIDs []OneTimeCryptoIDs
|
||||
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
|
||||
// the display name for their respective device, and NOT to modify the keys. The key
|
||||
// itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
|
||||
|
@ -812,6 +838,7 @@ type PerformUploadKeysResponse struct {
|
|||
// A map of user_id -> device_id -> Error for tracking failures.
|
||||
KeyErrors map[string]map[string]*KeyError
|
||||
OneTimeKeyCounts []OneTimeKeysCount
|
||||
OneTimeCryptoIDCounts []OneTimeCryptoIDsCount
|
||||
}
|
||||
|
||||
// PerformDeleteKeysRequest asks the keyserver to forget about certain
|
||||
|
|
|
@ -647,7 +647,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
|
|||
func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
|
||||
user := ""
|
||||
sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
||||
if err == nil {
|
||||
if err == nil && sender != nil {
|
||||
user = sender.String()
|
||||
}
|
||||
if user == mem.UserID {
|
||||
|
|
|
@ -17,9 +17,11 @@ package internal
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -55,11 +57,21 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
|
|||
if len(req.OneTimeKeys) > 0 {
|
||||
a.uploadOneTimeKeys(ctx, req, res)
|
||||
}
|
||||
if len(req.OneTimeCryptoIDs) > 0 {
|
||||
a.uploadOneTimeCryptoIDs(ctx, req, res)
|
||||
}
|
||||
logrus.Infof("One time cryptoIDs count before: %v", res.OneTimeCryptoIDCounts)
|
||||
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
|
||||
otpIDs, err := a.KeyDatabase.OneTimeCryptoIDsCount(ctx, req.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.OneTimeCryptoIDCounts = []api.OneTimeCryptoIDsCount{*otpIDs}
|
||||
logrus.Infof("One time cryptoIDs count after: %v", res.OneTimeCryptoIDCounts)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -181,6 +193,17 @@ func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOn
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryOneTimeCryptoIDs(ctx context.Context, userID string) (api.OneTimeCryptoIDsCount, *api.KeyError) {
|
||||
count, err := a.KeyDatabase.OneTimeCryptoIDsCount(ctx, userID)
|
||||
if err != nil {
|
||||
return api.OneTimeCryptoIDsCount{}, &api.KeyError{
|
||||
Err: fmt.Sprintf("Failed to query OTID counts: %s", err),
|
||||
}
|
||||
}
|
||||
return *count, nil
|
||||
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
|
||||
msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false)
|
||||
if err != nil {
|
||||
|
@ -773,6 +796,98 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor
|
|||
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) uploadOneTimeCryptoIDs(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||
if req.UserID == "" {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "user ID missing",
|
||||
}
|
||||
}
|
||||
if len(req.OneTimeCryptoIDs) == 0 {
|
||||
counts, err := a.KeyDatabase.OneTimeCryptoIDsCount(ctx, req.UserID)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.KeyDatabase.OneTimeCryptoIDsCount: %s", err),
|
||||
}
|
||||
}
|
||||
if counts != nil {
|
||||
logrus.Infof("Uploading one-time cryptoIDs: early result count: %v", *counts)
|
||||
res.OneTimeCryptoIDCounts = append(res.OneTimeCryptoIDCounts, *counts)
|
||||
}
|
||||
return
|
||||
}
|
||||
for _, key := range req.OneTimeCryptoIDs {
|
||||
// grab existing keys based on (user/algorithm/key ID)
|
||||
keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
|
||||
i := 0
|
||||
for keyIDWithAlgo := range key.KeyJSON {
|
||||
keyIDsWithAlgorithms[i] = keyIDWithAlgo
|
||||
i++
|
||||
}
|
||||
existingKeys, err := a.KeyDatabase.ExistingOneTimeCryptoIDs(ctx, req.UserID, keyIDsWithAlgorithms)
|
||||
if err != nil {
|
||||
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||
Err: "failed to query existing one-time cryptoIDs: " + err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
for keyIDWithAlgo := range existingKeys {
|
||||
// if keys exist and the JSON doesn't match, error out as the key already exists
|
||||
if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) {
|
||||
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||
Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time cryptoID already exists", req.UserID, req.DeviceID, keyIDWithAlgo),
|
||||
})
|
||||
continue
|
||||
}
|
||||
}
|
||||
// store one-time keys
|
||||
counts, err := a.KeyDatabase.StoreOneTimeCryptoIDs(ctx, key)
|
||||
if err != nil {
|
||||
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||
Err: fmt.Sprintf("%s device %s : failed to store one-time cryptoIDs: %s", req.UserID, req.DeviceID, err.Error()),
|
||||
})
|
||||
continue
|
||||
}
|
||||
// collect counts
|
||||
logrus.Infof("Uploading one-time cryptoIDs: result count: %v", *counts)
|
||||
res.OneTimeCryptoIDCounts = append(res.OneTimeCryptoIDCounts, *counts)
|
||||
}
|
||||
}
|
||||
|
||||
type Ed25519Key struct {
|
||||
Key spec.Base64Bytes `json:"key"`
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) ClaimOneTimeCryptoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
|
||||
cryptoID, err := a.KeyDatabase.ClaimOneTimeCryptoID(ctx, userID, "ed25519")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
logrus.Infof("Claimed one time cryptoID: %s", cryptoID)
|
||||
|
||||
if cryptoID != nil {
|
||||
for key, value := range cryptoID.KeyJSON {
|
||||
keyParts := strings.Split(key, ":")
|
||||
if keyParts[0] == "ed25519" {
|
||||
var key_bytes Ed25519Key
|
||||
err := json.Unmarshal(value, &key_bytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
length := len(key_bytes.Key)
|
||||
if length != ed25519.PublicKeySize {
|
||||
return "", fmt.Errorf("Invalid ed25519 public key, %d is the wrong size", length)
|
||||
}
|
||||
// TODO: cryptoIDs - store senderID for this user here?
|
||||
return spec.SenderID(key_bytes.Key.Encode()), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("failed claiming a valid one time cryptoID for this user: %s", userID.String())
|
||||
}
|
||||
|
||||
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
|
||||
// if we only want to update the display names, we can skip the checks below
|
||||
if onlyUpdateDisplayName {
|
||||
|
|
|
@ -175,6 +175,11 @@ type KeyDatabase interface {
|
|||
// OneTimeKeysCount returns a count of all OTKs for this device.
|
||||
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
|
||||
|
||||
ExistingOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||
StoreOneTimeCryptoIDs(ctx context.Context, keys api.OneTimeCryptoIDs) (*api.OneTimeCryptoIDsCount, error)
|
||||
OneTimeCryptoIDsCount(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error)
|
||||
ClaimOneTimeCryptoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimeCryptoIDs, error)
|
||||
|
||||
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
||||
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||
|
||||
|
|
191
userapi/storage/postgres/one_time_cryptoids_table.go
Normal file
191
userapi/storage/postgres/one_time_cryptoids_table.go
Normal file
|
@ -0,0 +1,191 @@
|
|||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
var oneTimeCryptoIDsSchema = `
|
||||
-- Stores one-time cryptoIDs for users
|
||||
CREATE TABLE IF NOT EXISTS keyserver_one_time_cryptoids (
|
||||
user_id TEXT NOT NULL,
|
||||
key_id TEXT NOT NULL,
|
||||
algorithm TEXT NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL,
|
||||
key_json TEXT NOT NULL,
|
||||
-- Clobber based on 3-uple of user/key/algorithm.
|
||||
CONSTRAINT keyserver_one_time_cryptoids_unique UNIQUE (user_id, key_id, algorithm)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_one_time_cryptoids_idx ON keyserver_one_time_cryptoids (user_id);
|
||||
`
|
||||
|
||||
const upsertCryptoIDsSQL = "" +
|
||||
"INSERT INTO keyserver_one_time_cryptoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
|
||||
" VALUES ($1, $2, $3, $4, $5)" +
|
||||
" ON CONFLICT ON CONSTRAINT keyserver_one_time_cryptoids_unique" +
|
||||
" DO UPDATE SET key_json = $5"
|
||||
|
||||
const selectOneTimeCryptoIDsSQL = "" +
|
||||
"SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_cryptoids WHERE user_id=$1 AND concat(algorithm, ':', key_id) = ANY($2);"
|
||||
|
||||
const selectCryptoIDsCountSQL = "" +
|
||||
"SELECT algorithm, COUNT(key_id) FROM " +
|
||||
" (SELECT algorithm, key_id FROM keyserver_one_time_cryptoids WHERE user_id = $1 LIMIT 100)" +
|
||||
" x GROUP BY algorithm"
|
||||
|
||||
const deleteOneTimeCryptoIDSQL = "" +
|
||||
"DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"
|
||||
|
||||
const selectCryptoIDByAlgorithmSQL = "" +
|
||||
"SELECT key_id, key_json FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"
|
||||
|
||||
const deleteOneTimeCryptoIDsSQL = "" +
|
||||
"DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1"
|
||||
|
||||
type oneTimeCryptoIDsStatements struct {
|
||||
db *sql.DB
|
||||
upsertCryptoIDsStmt *sql.Stmt
|
||||
selectCryptoIDsStmt *sql.Stmt
|
||||
selectCryptoIDsCountStmt *sql.Stmt
|
||||
selectCryptoIDByAlgorithmStmt *sql.Stmt
|
||||
deleteOneTimeCryptoIDStmt *sql.Stmt
|
||||
deleteOneTimeCryptoIDsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresOneTimeCryptoIDsTable(db *sql.DB) (tables.OneTimeCryptoIDs, error) {
|
||||
s := &oneTimeCryptoIDsStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(oneTimeCryptoIDsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertCryptoIDsStmt, upsertCryptoIDsSQL},
|
||||
{&s.selectCryptoIDsStmt, selectOneTimeCryptoIDsSQL},
|
||||
{&s.selectCryptoIDsCountStmt, selectCryptoIDsCountSQL},
|
||||
{&s.selectCryptoIDByAlgorithmStmt, selectCryptoIDByAlgorithmSQL},
|
||||
{&s.deleteOneTimeCryptoIDStmt, deleteOneTimeCryptoIDSQL},
|
||||
{&s.deleteOneTimeCryptoIDsStmt, deleteOneTimeCryptoIDsSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *oneTimeCryptoIDsStatements) SelectOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||
rows, err := s.selectCryptoIDsStmt.QueryContext(ctx, userID, pq.Array(keyIDsWithAlgorithms))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsStmt: rows.close() failed")
|
||||
|
||||
result := make(map[string]json.RawMessage)
|
||||
var (
|
||||
algorithmWithID string
|
||||
keyJSONStr string
|
||||
)
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[algorithmWithID] = json.RawMessage(keyJSONStr)
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *oneTimeCryptoIDsStatements) CountOneTimeCryptoIDs(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) {
|
||||
counts := &api.OneTimeCryptoIDsCount{
|
||||
UserID: userID,
|
||||
KeyCount: make(map[string]int),
|
||||
}
|
||||
rows, err := s.selectCryptoIDsCountStmt.QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
var count int
|
||||
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts.KeyCount[algorithm] = count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (s *oneTimeCryptoIDsStatements) InsertOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimeCryptoIDs) (*api.OneTimeCryptoIDsCount, error) {
|
||||
now := time.Now().Unix()
|
||||
counts := &api.OneTimeCryptoIDsCount{
|
||||
UserID: keys.UserID,
|
||||
KeyCount: make(map[string]int),
|
||||
}
|
||||
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertCryptoIDsStmt).ExecContext(
|
||||
ctx, keys.UserID, keyID, algo, now, string(keyJSON),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectCryptoIDsCountStmt).QueryContext(ctx, keys.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
var count int
|
||||
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts.KeyCount[algorithm] = count
|
||||
}
|
||||
|
||||
return counts, rows.Err()
|
||||
}
|
||||
|
||||
func (s *oneTimeCryptoIDsStatements) SelectAndDeleteOneTimeCryptoID(
|
||||
ctx context.Context, txn *sql.Tx, userID, algorithm string,
|
||||
) (map[string]json.RawMessage, error) {
|
||||
var keyID string
|
||||
var keyJSON string
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectCryptoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeCryptoIDStmt).ExecContext(ctx, userID, algorithm, keyID)
|
||||
return map[string]json.RawMessage{
|
||||
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||
}, err
|
||||
}
|
||||
|
||||
func (s *oneTimeCryptoIDsStatements) DeleteOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteOneTimeCryptoIDsStmt).ExecContext(ctx, userID)
|
||||
return err
|
||||
}
|
|
@ -149,6 +149,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
otpid, err := NewPostgresOneTimeCryptoIDsTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dk, err := NewPostgresDeviceKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -172,6 +176,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
|||
|
||||
return &shared.KeyDatabase{
|
||||
OneTimeKeysTable: otk,
|
||||
OneTimeCryptoIDsTable: otpid,
|
||||
DeviceKeysTable: dk,
|
||||
KeyChangesTable: kc,
|
||||
StaleDeviceListsTable: sdl,
|
||||
|
|
|
@ -65,6 +65,7 @@ type Database struct {
|
|||
|
||||
type KeyDatabase struct {
|
||||
OneTimeKeysTable tables.OneTimeKeys
|
||||
OneTimeCryptoIDsTable tables.OneTimeCryptoIDs
|
||||
DeviceKeysTable tables.DeviceKeys
|
||||
KeyChangesTable tables.KeyChanges
|
||||
StaleDeviceListsTable tables.StaleDeviceLists
|
||||
|
@ -945,6 +946,40 @@ func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID str
|
|||
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) ExistingOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||
return d.OneTimeCryptoIDsTable.SelectOneTimeCryptoIDs(ctx, userID, keyIDsWithAlgorithms)
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) StoreOneTimeCryptoIDs(ctx context.Context, keys api.OneTimeCryptoIDs) (counts *api.OneTimeCryptoIDsCount, err error) {
|
||||
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
counts, err = d.OneTimeCryptoIDsTable.InsertOneTimeCryptoIDs(ctx, txn, keys)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) OneTimeCryptoIDsCount(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) {
|
||||
return d.OneTimeCryptoIDsTable.CountOneTimeCryptoIDs(ctx, userID)
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) ClaimOneTimeCryptoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimeCryptoIDs, error) {
|
||||
var result *api.OneTimeCryptoIDs
|
||||
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
keyJSON, err := d.OneTimeCryptoIDsTable.SelectAndDeleteOneTimeCryptoID(ctx, txn, userID.String(), algorithm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if keyJSON != nil {
|
||||
result = &api.OneTimeCryptoIDs{
|
||||
UserID: userID.String(),
|
||||
KeyJSON: keyJSON,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
||||
}
|
||||
|
|
208
userapi/storage/sqlite3/one_time_cryptoids_table.go
Normal file
208
userapi/storage/sqlite3/one_time_cryptoids_table.go
Normal file
|
@ -0,0 +1,208 @@
|
|||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var oneTimeCryptoIDsSchema = `
|
||||
-- Stores one-time cryptoIDs for users
|
||||
CREATE TABLE IF NOT EXISTS keyserver_one_time_cryptoids (
|
||||
user_id TEXT NOT NULL,
|
||||
key_id TEXT NOT NULL,
|
||||
algorithm TEXT NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL,
|
||||
key_json TEXT NOT NULL,
|
||||
-- Clobber based on 3-uple of user/key/algorithm.
|
||||
UNIQUE (user_id, key_id, algorithm)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_one_time_cryptoids_idx ON keyserver_one_time_cryptoids (user_id);
|
||||
`
|
||||
|
||||
const upsertCryptoIDsSQL = "" +
|
||||
"INSERT INTO keyserver_one_time_cryptoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
|
||||
" VALUES ($1, $2, $3, $4, $5)" +
|
||||
" ON CONFLICT (user_id, key_id, algorithm)" +
|
||||
" DO UPDATE SET key_json = $5"
|
||||
|
||||
const selectOneTimeCryptoIDsSQL = "" +
|
||||
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_cryptoids WHERE user_id=$1"
|
||||
|
||||
const selectCryptoIDsCountSQL = "" +
|
||||
"SELECT algorithm, COUNT(key_id) FROM " +
|
||||
" (SELECT algorithm, key_id FROM keyserver_one_time_cryptoids WHERE user_id = $1 LIMIT 100)" +
|
||||
" x GROUP BY algorithm"
|
||||
|
||||
const deleteOneTimeCryptoIDSQL = "" +
|
||||
"DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"
|
||||
|
||||
const selectCryptoIDByAlgorithmSQL = "" +
|
||||
"SELECT key_id, key_json FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"
|
||||
|
||||
const deleteOneTimeCryptoIDsSQL = "" +
|
||||
"DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1"
|
||||
|
||||
type oneTimeCryptoIDsStatements struct {
|
||||
db *sql.DB
|
||||
upsertCryptoIDsStmt *sql.Stmt
|
||||
selectCryptoIDsStmt *sql.Stmt
|
||||
selectCryptoIDsCountStmt *sql.Stmt
|
||||
selectCryptoIDByAlgorithmStmt *sql.Stmt
|
||||
deleteOneTimeCryptoIDStmt *sql.Stmt
|
||||
deleteOneTimeCryptoIDsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteOneTimeCryptoIDsTable(db *sql.DB) (tables.OneTimeCryptoIDs, error) {
|
||||
s := &oneTimeCryptoIDsStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(oneTimeCryptoIDsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertCryptoIDsStmt, upsertCryptoIDsSQL},
|
||||
{&s.selectCryptoIDsStmt, selectOneTimeCryptoIDsSQL},
|
||||
{&s.selectCryptoIDsCountStmt, selectCryptoIDsCountSQL},
|
||||
{&s.selectCryptoIDByAlgorithmStmt, selectCryptoIDByAlgorithmSQL},
|
||||
{&s.deleteOneTimeCryptoIDStmt, deleteOneTimeCryptoIDSQL},
|
||||
{&s.deleteOneTimeCryptoIDsStmt, deleteOneTimeCryptoIDsSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *oneTimeCryptoIDsStatements) SelectOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||
rows, err := s.selectCryptoIDsStmt.QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsStmt: rows.close() failed")
|
||||
|
||||
wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
|
||||
for _, ka := range keyIDsWithAlgorithms {
|
||||
wantSet[ka] = true
|
||||
}
|
||||
|
||||
result := make(map[string]json.RawMessage)
|
||||
for rows.Next() {
|
||||
var keyID string
|
||||
var algorithm string
|
||||
var keyJSONStr string
|
||||
if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyIDWithAlgo := algorithm + ":" + keyID
|
||||
if wantSet[keyIDWithAlgo] {
|
||||
result[keyIDWithAlgo] = json.RawMessage(keyJSONStr)
|
||||
}
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *oneTimeCryptoIDsStatements) CountOneTimeCryptoIDs(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) {
|
||||
counts := &api.OneTimeCryptoIDsCount{
|
||||
UserID: userID,
|
||||
KeyCount: make(map[string]int),
|
||||
}
|
||||
rows, err := s.selectCryptoIDsCountStmt.QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
var count int
|
||||
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts.KeyCount[algorithm] = count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (s *oneTimeCryptoIDsStatements) InsertOneTimeCryptoIDs(
|
||||
ctx context.Context, txn *sql.Tx, keys api.OneTimeCryptoIDs,
|
||||
) (*api.OneTimeCryptoIDsCount, error) {
|
||||
now := time.Now().Unix()
|
||||
counts := &api.OneTimeCryptoIDsCount{
|
||||
UserID: keys.UserID,
|
||||
KeyCount: make(map[string]int),
|
||||
}
|
||||
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertCryptoIDsStmt).ExecContext(
|
||||
ctx, keys.UserID, keyID, algo, now, string(keyJSON),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectCryptoIDsCountStmt).QueryContext(ctx, keys.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
var count int
|
||||
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts.KeyCount[algorithm] = count
|
||||
}
|
||||
|
||||
return counts, rows.Err()
|
||||
}
|
||||
|
||||
func (s *oneTimeCryptoIDsStatements) SelectAndDeleteOneTimeCryptoID(
|
||||
ctx context.Context, txn *sql.Tx, userID, algorithm string,
|
||||
) (map[string]json.RawMessage, error) {
|
||||
var keyID string
|
||||
var keyJSON string
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectCryptoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
logrus.Warnf("No rows found for one time cryptoIDs")
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeCryptoIDStmt).ExecContext(ctx, userID, algorithm, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if keyJSON == "" {
|
||||
logrus.Warnf("Empty key JSON for one time cryptoIDs")
|
||||
return nil, nil
|
||||
}
|
||||
return map[string]json.RawMessage{
|
||||
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||
}, err
|
||||
}
|
||||
|
||||
func (s *oneTimeCryptoIDsStatements) DeleteOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteOneTimeCryptoIDsStmt).ExecContext(ctx, userID)
|
||||
return err
|
||||
}
|
|
@ -146,6 +146,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
otpid, err := NewSqliteOneTimeCryptoIDsTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dk, err := NewSqliteDeviceKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -169,6 +173,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
|||
|
||||
return &shared.KeyDatabase{
|
||||
OneTimeKeysTable: otk,
|
||||
OneTimeCryptoIDsTable: otpid,
|
||||
DeviceKeysTable: dk,
|
||||
KeyChangesTable: kc,
|
||||
StaleDeviceListsTable: sdl,
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package storage_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
@ -758,3 +759,35 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOneTimeCryptoIDs(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clean := mustCreateKeyDatabase(t, dbType)
|
||||
defer clean()
|
||||
userID := "@alice:localhost"
|
||||
otk := api.OneTimeCryptoIDs{
|
||||
UserID: userID,
|
||||
KeyJSON: map[string]json.RawMessage{"pseudoid_curve25519:KEY1": []byte(`{"key":"v1"}`)},
|
||||
}
|
||||
|
||||
// Add a one time pseudoID to the DB
|
||||
_, err := db.StoreOneTimeCryptoIDs(ctx, otk)
|
||||
MustNotError(t, err)
|
||||
|
||||
// Check the count of one time pseudoIDs is correct
|
||||
count, err := db.OneTimeCryptoIDsCount(ctx, userID)
|
||||
MustNotError(t, err)
|
||||
if count.KeyCount["pseudoid_curve25519"] != 1 {
|
||||
t.Fatalf("Expected 1 pseudoID, got %d", count.KeyCount["pseudoid_curve25519"])
|
||||
}
|
||||
|
||||
// Check the actual pseudoid contents are correct
|
||||
keysJSON, err := db.ExistingOneTimeCryptoIDs(ctx, userID, []string{"pseudoid_curve25519:KEY1"})
|
||||
MustNotError(t, err)
|
||||
keyJSON, err := keysJSON["pseudoid_curve25519:KEY1"].MarshalJSON()
|
||||
MustNotError(t, err)
|
||||
if !bytes.Equal(keyJSON, []byte(`{"key":"v1"}`)) {
|
||||
t.Fatalf("Existing pseudoIDs do not match expected. Got %v", keysJSON["pseudoid_curve25519:KEY1"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -168,6 +168,14 @@ type OneTimeKeys interface {
|
|||
DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
|
||||
}
|
||||
|
||||
type OneTimeCryptoIDs interface {
|
||||
SelectOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||
CountOneTimeCryptoIDs(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error)
|
||||
InsertOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimeCryptoIDs) (*api.OneTimeCryptoIDsCount, error)
|
||||
SelectAndDeleteOneTimeCryptoID(ctx context.Context, txn *sql.Tx, userID, algorithm string) (map[string]json.RawMessage, error)
|
||||
DeleteOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, userID string) error
|
||||
}
|
||||
|
||||
type DeviceKeys interface {
|
||||
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
|
||||
|
|
Loading…
Reference in a new issue