Compare commits
16 commits
main
...
crypto-ids
Author | SHA1 | Date | |
---|---|---|---|
2acc285f38 | |||
59753aba98 | |||
cbd547c828 | |||
1dc00f5fd2 | |||
b45e72830e | |||
3cbccb9ed7 | |||
5930a04044 | |||
227493cc5d | |||
7f7ac0f4fe | |||
b7d320f8d1 | |||
60be1391bf | |||
038103ac7f | |||
29cd14baf5 | |||
4bfcf27106 | |||
f17de49c6b | |||
a5ba533cfb |
|
@ -48,6 +48,7 @@ type createRoomRequest struct {
|
||||||
RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"`
|
RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"`
|
||||||
PowerLevelContentOverride json.RawMessage `json:"power_level_content_override"`
|
PowerLevelContentOverride json.RawMessage `json:"power_level_content_override"`
|
||||||
IsDirect bool `json:"is_direct"`
|
IsDirect bool `json:"is_direct"`
|
||||||
|
CryptoID string `json:"cryptoid"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r createRoomRequest) Validate() *util.JSONResponse {
|
func (r createRoomRequest) Validate() *util.JSONResponse {
|
||||||
|
@ -107,12 +108,27 @@ type createRoomResponse struct {
|
||||||
RoomAlias string `json:"room_alias,omitempty"` // in synapse not spec
|
RoomAlias string `json:"room_alias,omitempty"` // in synapse not spec
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type createRoomCryptoIDsResponse struct {
|
||||||
|
RoomID string `json:"room_id"`
|
||||||
|
Version string `json:"room_version"`
|
||||||
|
PDUs []json.RawMessage `json:"pdus"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func ToProtoEvents(ctx context.Context, events []gomatrixserverlib.PDU, rsAPI roomserverAPI.ClientRoomserverAPI) []json.RawMessage {
|
||||||
|
result := make([]json.RawMessage, len(events))
|
||||||
|
for i, event := range events {
|
||||||
|
result[i] = json.RawMessage(event.JSON())
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
// CreateRoom implements /createRoom
|
// CreateRoom implements /createRoom
|
||||||
func CreateRoom(
|
func CreateRoom(
|
||||||
req *http.Request, device *api.Device,
|
req *http.Request, device *api.Device,
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
profileAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
|
profileAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
|
||||||
asAPI appserviceAPI.AppServiceInternalAPI,
|
asAPI appserviceAPI.AppServiceInternalAPI,
|
||||||
|
cryptoIDs bool,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
var createRequest createRoomRequest
|
var createRequest createRoomRequest
|
||||||
resErr := httputil.UnmarshalJSONRequest(req, &createRequest)
|
resErr := httputil.UnmarshalJSONRequest(req, &createRequest)
|
||||||
|
@ -129,10 +145,9 @@ func CreateRoom(
|
||||||
JSON: spec.InvalidParam(err.Error()),
|
JSON: spec.InvalidParam(err.Error()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return createRoom(req.Context(), createRequest, device, cfg, profileAPI, rsAPI, asAPI, evTime)
|
return createRoom(req.Context(), createRequest, device, cfg, profileAPI, rsAPI, asAPI, evTime, cryptoIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// createRoom implements /createRoom
|
|
||||||
func createRoom(
|
func createRoom(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
createRequest createRoomRequest, device *api.Device,
|
createRequest createRoomRequest, device *api.Device,
|
||||||
|
@ -140,6 +155,7 @@ func createRoom(
|
||||||
profileAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
|
profileAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
|
||||||
asAPI appserviceAPI.AppServiceInternalAPI,
|
asAPI appserviceAPI.AppServiceInternalAPI,
|
||||||
evTime time.Time,
|
evTime time.Time,
|
||||||
|
cryptoIDs bool,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
userID, err := spec.NewUserID(device.UserID, true)
|
userID, err := spec.NewUserID(device.UserID, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -225,6 +241,7 @@ func createRoom(
|
||||||
EventTime: evTime,
|
EventTime: evTime,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !cryptoIDs {
|
||||||
roomAlias, createRes := rsAPI.PerformCreateRoom(ctx, *userID, *roomID, &req)
|
roomAlias, createRes := rsAPI.PerformCreateRoom(ctx, *userID, *roomID, &req)
|
||||||
if createRes != nil {
|
if createRes != nil {
|
||||||
return *createRes
|
return *createRes
|
||||||
|
@ -239,4 +256,27 @@ func createRoom(
|
||||||
Code: 200,
|
Code: 200,
|
||||||
JSON: response,
|
JSON: response,
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
req.SenderID = createRequest.CryptoID
|
||||||
|
createEvents, err := rsAPI.PerformCreateRoomCryptoIDs(ctx, *userID, *roomID, &req)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("MakeCreateRoomEvents failed")
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: spec.InternalServerError{Err: err.Error()},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
response := createRoomCryptoIDsResponse{
|
||||||
|
RoomID: roomID.String(),
|
||||||
|
Version: string(roomVersion),
|
||||||
|
PDUs: ToProtoEvents(ctx, createEvents, rsAPI),
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: response,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,6 +27,7 @@ import (
|
||||||
"github.com/matrix-org/gomatrix"
|
"github.com/matrix-org/gomatrix"
|
||||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
func JoinRoomByIDOrAlias(
|
func JoinRoomByIDOrAlias(
|
||||||
|
@ -144,3 +145,139 @@ func JoinRoomByIDOrAlias(
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type joinRoomCryptoIDsResponse struct {
|
||||||
|
RoomID string `json:"room_id"`
|
||||||
|
Version string `json:"room_version"`
|
||||||
|
ViaServer string `json:"via_server"`
|
||||||
|
PDU json.RawMessage `json:"pdu"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func JoinRoomByIDOrAliasCryptoIDs(
|
||||||
|
req *http.Request,
|
||||||
|
device *api.Device,
|
||||||
|
rsAPI roomserverAPI.ClientRoomserverAPI,
|
||||||
|
profileAPI api.ClientUserAPI,
|
||||||
|
roomIDOrAlias string,
|
||||||
|
) util.JSONResponse {
|
||||||
|
// Prepare to ask the roomserver to perform the room join.
|
||||||
|
joinReq := roomserverAPI.PerformJoinRequest{
|
||||||
|
RoomIDOrAlias: roomIDOrAlias,
|
||||||
|
UserID: device.UserID,
|
||||||
|
IsGuest: device.AccountType == api.AccountTypeGuest,
|
||||||
|
Content: map[string]interface{}{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check to see if any ?server_name= query parameters were
|
||||||
|
// given in the request.
|
||||||
|
if serverNames, ok := req.URL.Query()["server_name"]; ok {
|
||||||
|
for _, serverName := range serverNames {
|
||||||
|
joinReq.ServerNames = append(
|
||||||
|
joinReq.ServerNames,
|
||||||
|
spec.ServerName(serverName),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If content was provided in the request then include that
|
||||||
|
// in the request. It'll get used as a part of the membership
|
||||||
|
// event content.
|
||||||
|
_ = httputil.UnmarshalJSONRequest(req, &joinReq.Content)
|
||||||
|
|
||||||
|
if senderid, ok := joinReq.Content["cryptoid"]; ok {
|
||||||
|
logrus.Errorf("CryptoID: %s", senderid.(string))
|
||||||
|
joinReq.SenderID = spec.SenderID(senderid.(string))
|
||||||
|
} else {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.Unknown("Missing cryptoid in request body"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(joinReq.Content, "cryptoid")
|
||||||
|
|
||||||
|
// Work out our localpart for the client profile request.
|
||||||
|
|
||||||
|
// Request our profile content to populate the request content with.
|
||||||
|
profile, err := profileAPI.QueryProfile(req.Context(), device.UserID)
|
||||||
|
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
joinReq.Content["displayname"] = profile.DisplayName
|
||||||
|
joinReq.Content["avatar_url"] = profile.AvatarURL
|
||||||
|
case appserviceAPI.ErrProfileNotExists:
|
||||||
|
util.GetLogger(req.Context()).Error("Unable to query user profile, no profile found.")
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: spec.Unknown("Unable to query user profile, no profile found."),
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ask the roomserver to perform the join.
|
||||||
|
done := make(chan util.JSONResponse, 1)
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
joinEvent, roomID, version, serverName, err := rsAPI.PerformJoinCryptoIDs(req.Context(), &joinReq)
|
||||||
|
var response util.JSONResponse
|
||||||
|
|
||||||
|
switch e := err.(type) {
|
||||||
|
case nil: // success case
|
||||||
|
response = util.JSONResponse{
|
||||||
|
Code: http.StatusOK,
|
||||||
|
JSON: joinRoomCryptoIDsResponse{
|
||||||
|
RoomID: roomID,
|
||||||
|
Version: string(version),
|
||||||
|
ViaServer: string(serverName),
|
||||||
|
PDU: json.RawMessage(joinEvent.JSON()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
case roomserverAPI.ErrInvalidID:
|
||||||
|
response = util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.Unknown(e.Error()),
|
||||||
|
}
|
||||||
|
case roomserverAPI.ErrNotAllowed:
|
||||||
|
jsonErr := spec.Forbidden(e.Error())
|
||||||
|
if device.AccountType == api.AccountTypeGuest {
|
||||||
|
jsonErr = spec.GuestAccessForbidden(e.Error())
|
||||||
|
}
|
||||||
|
response = util.JSONResponse{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
JSON: jsonErr,
|
||||||
|
}
|
||||||
|
case *gomatrix.HTTPError: // this ensures we proxy responses over federation to the client
|
||||||
|
response = util.JSONResponse{
|
||||||
|
Code: e.Code,
|
||||||
|
JSON: json.RawMessage(e.Message),
|
||||||
|
}
|
||||||
|
case eventutil.ErrRoomNoExists:
|
||||||
|
response = util.JSONResponse{
|
||||||
|
Code: http.StatusNotFound,
|
||||||
|
JSON: spec.NotFound(e.Error()),
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
response = util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: spec.InternalServerError{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
done <- response
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait either for the join to finish, or for us to hit a reasonable
|
||||||
|
// timeout, at which point we'll just return a 200 to placate clients.
|
||||||
|
timer := time.NewTimer(time.Second * 20)
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusRequestTimeout,
|
||||||
|
JSON: spec.Unknown("Failed creating join event with the remote server."),
|
||||||
|
}
|
||||||
|
case result := <-done:
|
||||||
|
// Stop and drain the timer
|
||||||
|
if !timer.Stop() {
|
||||||
|
<-timer.C
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -67,7 +67,7 @@ func TestJoinRoomByIDOrAlias(t *testing.T) {
|
||||||
Preset: spec.PresetPublicChat,
|
Preset: spec.PresetPublicChat,
|
||||||
RoomAliasName: "alias",
|
RoomAliasName: "alias",
|
||||||
Invite: []string{bob.ID},
|
Invite: []string{bob.ID},
|
||||||
}, aliceDev, &cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now())
|
}, aliceDev, &cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now(), false)
|
||||||
crResp, ok := resp.JSON.(createRoomResponse)
|
crResp, ok := resp.JSON.(createRoomResponse)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("response is not a createRoomResponse: %+v", resp)
|
t.Fatalf("response is not a createRoomResponse: %+v", resp)
|
||||||
|
@ -81,7 +81,7 @@ func TestJoinRoomByIDOrAlias(t *testing.T) {
|
||||||
Visibility: "public",
|
Visibility: "public",
|
||||||
Preset: spec.PresetPublicChat,
|
Preset: spec.PresetPublicChat,
|
||||||
Invite: []string{charlie.ID},
|
Invite: []string{charlie.ID},
|
||||||
}, aliceDev, &cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now())
|
}, aliceDev, &cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now(), false)
|
||||||
crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse)
|
crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("response is not a createRoomResponse: %+v", resp)
|
t.Fatalf("response is not a createRoomResponse: %+v", resp)
|
||||||
|
|
|
@ -97,6 +97,92 @@ type queryKeysRequest struct {
|
||||||
DeviceKeys map[string][]string `json:"device_keys"`
|
DeviceKeys map[string][]string `json:"device_keys"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type uploadKeysCryptoIDsRequest struct {
|
||||||
|
DeviceKeys json.RawMessage `json:"device_keys"`
|
||||||
|
OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"`
|
||||||
|
OneTimeCryptoIDs map[string]json.RawMessage `json:"one_time_cryptoids"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse {
|
||||||
|
var r uploadKeysCryptoIDsRequest
|
||||||
|
resErr := httputil.UnmarshalJSONRequest(req, &r)
|
||||||
|
if resErr != nil {
|
||||||
|
return *resErr
|
||||||
|
}
|
||||||
|
|
||||||
|
uploadReq := &api.PerformUploadKeysRequest{
|
||||||
|
DeviceID: device.ID,
|
||||||
|
UserID: device.UserID,
|
||||||
|
}
|
||||||
|
if r.DeviceKeys != nil {
|
||||||
|
uploadReq.DeviceKeys = []api.DeviceKeys{
|
||||||
|
{
|
||||||
|
DeviceID: device.ID,
|
||||||
|
UserID: device.UserID,
|
||||||
|
KeyJSON: r.DeviceKeys,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.OneTimeKeys != nil {
|
||||||
|
uploadReq.OneTimeKeys = []api.OneTimeKeys{
|
||||||
|
{
|
||||||
|
DeviceID: device.ID,
|
||||||
|
UserID: device.UserID,
|
||||||
|
KeyJSON: r.OneTimeKeys,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.OneTimeCryptoIDs != nil {
|
||||||
|
uploadReq.OneTimeCryptoIDs = []api.OneTimeCryptoIDs{
|
||||||
|
{
|
||||||
|
UserID: device.UserID,
|
||||||
|
KeyJSON: r.OneTimeCryptoIDs,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
util.GetLogger(req.Context()).
|
||||||
|
WithField("device keys", r.DeviceKeys).
|
||||||
|
WithField("one-time keys", r.OneTimeKeys).
|
||||||
|
WithField("one-time cryptoids", r.OneTimeCryptoIDs).
|
||||||
|
Info("Uploading keys")
|
||||||
|
|
||||||
|
var uploadRes api.PerformUploadKeysResponse
|
||||||
|
if err := keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
if uploadRes.Error != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(uploadRes.Error).Error("Failed to PerformUploadKeys")
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: spec.InternalServerError{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(uploadRes.KeyErrors) > 0 {
|
||||||
|
util.GetLogger(req.Context()).WithField("key_errors", uploadRes.KeyErrors).Error("Failed to upload one or more keys")
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 400,
|
||||||
|
JSON: uploadRes.KeyErrors,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
keyCount := make(map[string]int)
|
||||||
|
if len(uploadRes.OneTimeKeyCounts) > 0 {
|
||||||
|
keyCount = uploadRes.OneTimeKeyCounts[0].KeyCount
|
||||||
|
}
|
||||||
|
cryptoIDCount := make(map[string]int)
|
||||||
|
if len(uploadRes.OneTimeCryptoIDCounts) > 0 {
|
||||||
|
cryptoIDCount = uploadRes.OneTimeCryptoIDCounts[0].KeyCount
|
||||||
|
}
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: struct {
|
||||||
|
OTKCounts interface{} `json:"one_time_key_counts"`
|
||||||
|
OTIDCounts interface{} `json:"one_time_cryptoid_counts"`
|
||||||
|
}{keyCount, cryptoIDCount},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *queryKeysRequest) GetTimeout() time.Duration {
|
func (r *queryKeysRequest) GetTimeout() time.Duration {
|
||||||
if r.Timeout == 0 {
|
if r.Timeout == 0 {
|
||||||
return 10 * time.Second
|
return 10 * time.Second
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
@ -23,11 +24,16 @@ import (
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type leaveRoomCryptoIDsResponse struct {
|
||||||
|
PDU json.RawMessage `json:"pdu"`
|
||||||
|
}
|
||||||
|
|
||||||
func LeaveRoomByID(
|
func LeaveRoomByID(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
device *api.Device,
|
device *api.Device,
|
||||||
rsAPI roomserverAPI.ClientRoomserverAPI,
|
rsAPI roomserverAPI.ClientRoomserverAPI,
|
||||||
roomID string,
|
roomID string,
|
||||||
|
cryptoIDs bool,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
userID, err := spec.NewUserID(device.UserID, true)
|
userID, err := spec.NewUserID(device.UserID, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -45,7 +51,8 @@ func LeaveRoomByID(
|
||||||
leaveRes := roomserverAPI.PerformLeaveResponse{}
|
leaveRes := roomserverAPI.PerformLeaveResponse{}
|
||||||
|
|
||||||
// Ask the roomserver to perform the leave.
|
// Ask the roomserver to perform the leave.
|
||||||
if err := rsAPI.PerformLeave(req.Context(), &leaveReq, &leaveRes); err != nil {
|
leaveEvent, err := rsAPI.PerformLeave(req.Context(), &leaveReq, &leaveRes, cryptoIDs)
|
||||||
|
if err != nil {
|
||||||
if leaveRes.Code != 0 {
|
if leaveRes.Code != 0 {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: leaveRes.Code,
|
Code: leaveRes.Code,
|
||||||
|
@ -58,8 +65,15 @@ func LeaveRoomByID(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cryptoIDs {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusOK,
|
||||||
|
JSON: leaveRoomCryptoIDsResponse{
|
||||||
|
PDU: json.RawMessage(leaveEvent.JSON()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
JSON: struct{}{},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@ package routing
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
@ -39,10 +40,15 @@ import (
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type membershipCryptoIDsResponse struct {
|
||||||
|
PDU json.RawMessage `json:"pdu"`
|
||||||
|
}
|
||||||
|
|
||||||
func SendBan(
|
func SendBan(
|
||||||
req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device,
|
req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device,
|
||||||
roomID string, cfg *config.ClientAPI,
|
roomID string, cfg *config.ClientAPI,
|
||||||
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI,
|
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI,
|
||||||
|
cryptoIDs bool,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
body, evTime, reqErr := extractRequestData(req)
|
body, evTime, reqErr := extractRequestData(req)
|
||||||
if reqErr != nil {
|
if reqErr != nil {
|
||||||
|
@ -95,13 +101,14 @@ func SendBan(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return sendMembership(req.Context(), profileAPI, device, roomID, spec.Ban, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI)
|
return sendMembership(req.Context(), profileAPI, device, roomID, spec.Ban, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI, cryptoIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, device *userapi.Device,
|
func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, device *userapi.Device,
|
||||||
roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time,
|
roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time,
|
||||||
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI) util.JSONResponse {
|
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, cryptoIDs bool) util.JSONResponse {
|
||||||
|
|
||||||
|
// TODO: cryptoIDs - what about when we don't know the senderID for a user?
|
||||||
event, err := buildMembershipEvent(
|
event, err := buildMembershipEvent(
|
||||||
ctx, targetUserID, reason, profileAPI, device, membership,
|
ctx, targetUserID, reason, profileAPI, device, membership,
|
||||||
roomID, false, cfg, evTime, rsAPI, asAPI,
|
roomID, false, cfg, evTime, rsAPI, asAPI,
|
||||||
|
@ -114,6 +121,8 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var json interface{}
|
||||||
|
if !cryptoIDs {
|
||||||
serverName := device.UserDomain()
|
serverName := device.UserDomain()
|
||||||
if err = roomserverAPI.SendEvents(
|
if err = roomserverAPI.SendEvents(
|
||||||
ctx, rsAPI,
|
ctx, rsAPI,
|
||||||
|
@ -131,10 +140,14 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic
|
||||||
JSON: spec.InternalServerError{},
|
JSON: spec.InternalServerError{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
json = struct{}{}
|
||||||
|
} else {
|
||||||
|
json = membershipCryptoIDsResponse{PDU: event.JSON()}
|
||||||
|
}
|
||||||
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
JSON: struct{}{},
|
JSON: json,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -142,6 +155,7 @@ func SendKick(
|
||||||
req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device,
|
req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device,
|
||||||
roomID string, cfg *config.ClientAPI,
|
roomID string, cfg *config.ClientAPI,
|
||||||
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI,
|
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI,
|
||||||
|
cryptoIDs bool,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
body, evTime, reqErr := extractRequestData(req)
|
body, evTime, reqErr := extractRequestData(req)
|
||||||
if reqErr != nil {
|
if reqErr != nil {
|
||||||
|
@ -216,13 +230,14 @@ func SendKick(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// TODO: should we be using SendLeave instead?
|
// TODO: should we be using SendLeave instead?
|
||||||
return sendMembership(req.Context(), profileAPI, device, roomID, spec.Leave, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI)
|
return sendMembership(req.Context(), profileAPI, device, roomID, spec.Leave, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI, cryptoIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func SendUnban(
|
func SendUnban(
|
||||||
req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device,
|
req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device,
|
||||||
roomID string, cfg *config.ClientAPI,
|
roomID string, cfg *config.ClientAPI,
|
||||||
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI,
|
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI,
|
||||||
|
cryptoIDs bool,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
body, evTime, reqErr := extractRequestData(req)
|
body, evTime, reqErr := extractRequestData(req)
|
||||||
if reqErr != nil {
|
if reqErr != nil {
|
||||||
|
@ -272,13 +287,14 @@ func SendUnban(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// TODO: should we be using SendLeave instead?
|
// TODO: should we be using SendLeave instead?
|
||||||
return sendMembership(req.Context(), profileAPI, device, roomID, spec.Leave, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI)
|
return sendMembership(req.Context(), profileAPI, device, roomID, spec.Leave, body.Reason, cfg, body.UserID, evTime, rsAPI, asAPI, cryptoIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func SendInvite(
|
func SendInvite(
|
||||||
req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device,
|
req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device,
|
||||||
roomID string, cfg *config.ClientAPI,
|
roomID string, cfg *config.ClientAPI,
|
||||||
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI,
|
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI,
|
||||||
|
cryptoIDs bool,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
body, evTime, reqErr := extractRequestData(req)
|
body, evTime, reqErr := extractRequestData(req)
|
||||||
if reqErr != nil {
|
if reqErr != nil {
|
||||||
|
@ -323,7 +339,7 @@ func SendInvite(
|
||||||
}
|
}
|
||||||
|
|
||||||
// We already received the return value, so no need to check for an error here.
|
// We already received the return value, so no need to check for an error here.
|
||||||
response, _ := sendInvite(req.Context(), profileAPI, device, roomID, body.UserID, body.Reason, cfg, rsAPI, asAPI, evTime)
|
response, _ := sendInvite(req.Context(), profileAPI, device, roomID, body.UserID, body.Reason, cfg, rsAPI, asAPI, evTime, cryptoIDs)
|
||||||
return response
|
return response
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -336,6 +352,7 @@ func sendInvite(
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
rsAPI roomserverAPI.ClientRoomserverAPI,
|
rsAPI roomserverAPI.ClientRoomserverAPI,
|
||||||
asAPI appserviceAPI.AppServiceInternalAPI, evTime time.Time,
|
asAPI appserviceAPI.AppServiceInternalAPI, evTime time.Time,
|
||||||
|
cryptoIDs bool,
|
||||||
) (util.JSONResponse, error) {
|
) (util.JSONResponse, error) {
|
||||||
validRoomID, err := spec.NewRoomID(roomID)
|
validRoomID, err := spec.NewRoomID(roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -372,7 +389,7 @@ func sendInvite(
|
||||||
JSON: spec.InternalServerError{},
|
JSON: spec.InternalServerError{},
|
||||||
}, err
|
}, err
|
||||||
}
|
}
|
||||||
err = rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{
|
inviteEvent, err := rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{
|
||||||
InviteInput: roomserverAPI.InviteInput{
|
InviteInput: roomserverAPI.InviteInput{
|
||||||
RoomID: *validRoomID,
|
RoomID: *validRoomID,
|
||||||
Inviter: *inviter,
|
Inviter: *inviter,
|
||||||
|
@ -387,7 +404,7 @@ func sendInvite(
|
||||||
},
|
},
|
||||||
InviteRoomState: nil, // ask the roomserver to draw up invite room state for us
|
InviteRoomState: nil, // ask the roomserver to draw up invite room state for us
|
||||||
SendAsServer: string(device.UserDomain()),
|
SendAsServer: string(device.UserDomain()),
|
||||||
})
|
}, cryptoIDs)
|
||||||
|
|
||||||
switch e := err.(type) {
|
switch e := err.(type) {
|
||||||
case roomserverAPI.ErrInvalidID:
|
case roomserverAPI.ErrInvalidID:
|
||||||
|
@ -410,10 +427,22 @@ func sendInvite(
|
||||||
}, err
|
}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return util.JSONResponse{
|
response := util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
JSON: struct{}{},
|
JSON: struct{}{},
|
||||||
}, nil
|
}
|
||||||
|
|
||||||
|
type inviteCryptoIDResponse struct {
|
||||||
|
PDU json.RawMessage `json:"pdu"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if inviteEvent != nil {
|
||||||
|
response.JSON = inviteCryptoIDResponse{
|
||||||
|
PDU: json.RawMessage(inviteEvent.JSON()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildMembershipEventDirect(
|
func buildMembershipEventDirect(
|
||||||
|
|
|
@ -309,9 +309,33 @@ func Setup(
|
||||||
|
|
||||||
v3mux.Handle("/createRoom",
|
v3mux.Handle("/createRoom",
|
||||||
httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
return CreateRoom(req, device, cfg, userAPI, rsAPI, asAPI)
|
return CreateRoom(req, device, cfg, userAPI, rsAPI, asAPI, false)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
unstableMux.Handle("/org.matrix.msc4080/createRoom",
|
||||||
|
httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
logrus.Info("Processing request to /org.matrix.msc4080/createRoom")
|
||||||
|
return CreateRoom(req, device, cfg, userAPI, rsAPI, asAPI, true)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
unstableMux.Handle("/org.matrix.msc4080/send_pdus/{txnID}",
|
||||||
|
httputil.MakeAuthAPI("send_pdus", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
logrus.Info("Processing request to /org.matrix.msc4080/send_pdus")
|
||||||
|
if r := rateLimits.Limit(req, device); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
|
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
txnID := vars["txnID"]
|
||||||
|
|
||||||
|
// NOTE: when making events such as for create_room, multiple PDUs will need to be passed between the client & server.
|
||||||
|
return SendPDUs(req, device, cfg, userAPI, rsAPI, asAPI, &txnID)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/join/{roomIDOrAlias}",
|
v3mux.Handle("/join/{roomIDOrAlias}",
|
||||||
httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
if r := rateLimits.Limit(req, device); r != nil {
|
if r := rateLimits.Limit(req, device); r != nil {
|
||||||
|
@ -334,8 +358,32 @@ func Setup(
|
||||||
return resp.(util.JSONResponse)
|
return resp.(util.JSONResponse)
|
||||||
}, httputil.WithAllowGuests()),
|
}, httputil.WithAllowGuests()),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
unstableMux.Handle("/org.matrix.msc4080/join/{roomIDOrAlias}",
|
||||||
|
httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
logrus.Info("Processing request to /org.matrix.msc4080/join/{roomIDOrAlias}")
|
||||||
|
if r := rateLimits.Limit(req, device); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
// Only execute a join for roomIDOrAlias and UserID once. If there is a join in progress
|
||||||
|
// it waits for it to complete and returns that result for subsequent requests.
|
||||||
|
resp, _, _ := sf.Do(vars["roomIDOrAlias"]+device.UserID, func() (any, error) {
|
||||||
|
return JoinRoomByIDOrAliasCryptoIDs(
|
||||||
|
req, device, rsAPI, userAPI, vars["roomIDOrAlias"],
|
||||||
|
), nil
|
||||||
|
})
|
||||||
|
// once all joins are processed, drop them from the cache. Further requests
|
||||||
|
// will be processed as usual.
|
||||||
|
sf.Forget(vars["roomIDOrAlias"] + device.UserID)
|
||||||
|
return resp.(util.JSONResponse)
|
||||||
|
}, httputil.WithAllowGuests()),
|
||||||
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
if mscCfg.Enabled("msc2753") {
|
if mscCfg.Enabled("msc2753") {
|
||||||
|
// TODO: update for cryptoIDs
|
||||||
v3mux.Handle("/peek/{roomIDOrAlias}",
|
v3mux.Handle("/peek/{roomIDOrAlias}",
|
||||||
httputil.MakeAuthAPI(spec.Peek, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI(spec.Peek, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
if r := rateLimits.Limit(req, device); r != nil {
|
if r := rateLimits.Limit(req, device); r != nil {
|
||||||
|
@ -378,6 +426,29 @@ func Setup(
|
||||||
return resp.(util.JSONResponse)
|
return resp.(util.JSONResponse)
|
||||||
}, httputil.WithAllowGuests()),
|
}, httputil.WithAllowGuests()),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/join",
|
||||||
|
httputil.MakeAuthAPI(spec.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/join")
|
||||||
|
if r := rateLimits.Limit(req, device); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
// Only execute a join for roomID and UserID once. If there is a join in progress
|
||||||
|
// it waits for it to complete and returns that result for subsequent requests.
|
||||||
|
resp, _, _ := sf.Do(vars["roomID"]+device.UserID, func() (any, error) {
|
||||||
|
return JoinRoomByIDOrAliasCryptoIDs(
|
||||||
|
req, device, rsAPI, userAPI, vars["roomID"],
|
||||||
|
), nil
|
||||||
|
})
|
||||||
|
// once all joins are processed, drop them from the cache. Further requests
|
||||||
|
// will be processed as usual.
|
||||||
|
sf.Forget(vars["roomID"] + device.UserID)
|
||||||
|
return resp.(util.JSONResponse)
|
||||||
|
}, httputil.WithAllowGuests()),
|
||||||
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
v3mux.Handle("/rooms/{roomID}/leave",
|
v3mux.Handle("/rooms/{roomID}/leave",
|
||||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
if r := rateLimits.Limit(req, device); r != nil {
|
if r := rateLimits.Limit(req, device); r != nil {
|
||||||
|
@ -388,10 +459,26 @@ func Setup(
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return LeaveRoomByID(
|
return LeaveRoomByID(
|
||||||
req, device, rsAPI, vars["roomID"],
|
req, device, rsAPI, vars["roomID"], false,
|
||||||
)
|
)
|
||||||
}, httputil.WithAllowGuests()),
|
}, httputil.WithAllowGuests()),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/leave",
|
||||||
|
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/leave")
|
||||||
|
if r := rateLimits.Limit(req, device); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return LeaveRoomByID(
|
||||||
|
req, device, rsAPI, vars["roomID"], true,
|
||||||
|
)
|
||||||
|
}, httputil.WithAllowGuests()),
|
||||||
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
// TODO: update for cryptoIDs
|
||||||
v3mux.Handle("/rooms/{roomID}/unpeek",
|
v3mux.Handle("/rooms/{roomID}/unpeek",
|
||||||
httputil.MakeAuthAPI("unpeek", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("unpeek", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
@ -409,7 +496,17 @@ func Setup(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, false)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/ban",
|
||||||
|
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/ban")
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, true)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
v3mux.Handle("/rooms/{roomID}/invite",
|
v3mux.Handle("/rooms/{roomID}/invite",
|
||||||
|
@ -421,7 +518,20 @@ func Setup(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, false)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/invite",
|
||||||
|
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/invite")
|
||||||
|
if r := rateLimits.Limit(req, device); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, true)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
v3mux.Handle("/rooms/{roomID}/kick",
|
v3mux.Handle("/rooms/{roomID}/kick",
|
||||||
|
@ -430,7 +540,17 @@ func Setup(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return SendKick(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
return SendKick(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, false)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/kick",
|
||||||
|
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/kick")
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return SendKick(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, true)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
v3mux.Handle("/rooms/{roomID}/unban",
|
v3mux.Handle("/rooms/{roomID}/unban",
|
||||||
|
@ -439,7 +559,17 @@ func Setup(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return SendUnban(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
return SendUnban(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, false)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/unban",
|
||||||
|
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/unban")
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return SendUnban(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, true)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
v3mux.Handle("/rooms/{roomID}/send/{eventType}",
|
v3mux.Handle("/rooms/{roomID}/send/{eventType}",
|
||||||
|
@ -451,6 +581,16 @@ func Setup(
|
||||||
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil)
|
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil)
|
||||||
}, httputil.WithAllowGuests()),
|
}, httputil.WithAllowGuests()),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/send/{eventType}",
|
||||||
|
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/send/{eventType}")
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return SendEventCryptoIDs(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil)
|
||||||
|
}, httputil.WithAllowGuests()),
|
||||||
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}",
|
v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}",
|
||||||
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
@ -462,6 +602,18 @@ func Setup(
|
||||||
nil, cfg, rsAPI, transactionsCache)
|
nil, cfg, rsAPI, transactionsCache)
|
||||||
}, httputil.WithAllowGuests()),
|
}, httputil.WithAllowGuests()),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/send/{eventType}/{txnID}",
|
||||||
|
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/send/{eventType}/{txnID}")
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
txnID := vars["txnID"]
|
||||||
|
return SendEventCryptoIDs(req, device, vars["roomID"], vars["eventType"], &txnID,
|
||||||
|
nil, cfg, rsAPI, transactionsCache)
|
||||||
|
}, httputil.WithAllowGuests()),
|
||||||
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
@ -510,6 +662,18 @@ func Setup(
|
||||||
return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil)
|
return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil)
|
||||||
}, httputil.WithAllowGuests()),
|
}, httputil.WithAllowGuests()),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/state/{eventType:[^/]+/?}",
|
||||||
|
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/state/{eventType:[^/]+/?}")
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
emptyString := ""
|
||||||
|
eventType := strings.TrimSuffix(vars["eventType"], "/")
|
||||||
|
return SendEventCryptoIDs(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil)
|
||||||
|
}, httputil.WithAllowGuests()),
|
||||||
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
|
v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
|
||||||
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
@ -521,6 +685,17 @@ func Setup(
|
||||||
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil)
|
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil)
|
||||||
}, httputil.WithAllowGuests()),
|
}, httputil.WithAllowGuests()),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
unstableMux.Handle("/org.matrix.msc4080/rooms/{roomID}/state/{eventType}/{stateKey}",
|
||||||
|
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
logrus.Info("Processing request to /org.matrix.msc4080/rooms/{roomID}/state/{eventType}/{stateKey}")
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
stateKey := vars["stateKey"]
|
||||||
|
return SendEventCryptoIDs(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil)
|
||||||
|
}, httputil.WithAllowGuests()),
|
||||||
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
|
||||||
// Defined outside of handler to persist between calls
|
// Defined outside of handler to persist between calls
|
||||||
// TODO: clear based on some criteria
|
// TODO: clear based on some criteria
|
||||||
|
@ -559,6 +734,7 @@ func Setup(
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
|
// TODO: update for cryptoIDs
|
||||||
v3mux.Handle("/directory/room/{roomAlias}",
|
v3mux.Handle("/directory/room/{roomAlias}",
|
||||||
httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
@ -569,6 +745,7 @@ func Setup(
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
|
||||||
|
// TODO: update for cryptoIDs
|
||||||
v3mux.Handle("/directory/room/{roomAlias}",
|
v3mux.Handle("/directory/room/{roomAlias}",
|
||||||
httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
@ -636,6 +813,7 @@ func Setup(
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
|
// TODO: update for cryptoIDs
|
||||||
v3mux.Handle("/rooms/{roomID}/typing/{userID}",
|
v3mux.Handle("/rooms/{roomID}/typing/{userID}",
|
||||||
httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
if r := rateLimits.Limit(req, device); r != nil {
|
if r := rateLimits.Limit(req, device); r != nil {
|
||||||
|
@ -648,6 +826,7 @@ func Setup(
|
||||||
return SendTyping(req, device, vars["roomID"], vars["userID"], rsAPI, syncProducer)
|
return SendTyping(req, device, vars["roomID"], vars["userID"], rsAPI, syncProducer)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
// TODO: update for cryptoIDs
|
||||||
v3mux.Handle("/rooms/{roomID}/redact/{eventID}",
|
v3mux.Handle("/rooms/{roomID}/redact/{eventID}",
|
||||||
httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
@ -657,6 +836,7 @@ func Setup(
|
||||||
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, nil, nil)
|
return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, nil, nil)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
// TODO: update for cryptoIDs
|
||||||
v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}",
|
v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}",
|
||||||
httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
@ -668,6 +848,7 @@ func Setup(
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
|
||||||
|
// TODO: update for cryptoIDs
|
||||||
v3mux.Handle("/sendToDevice/{eventType}/{txnID}",
|
v3mux.Handle("/sendToDevice/{eventType}/{txnID}",
|
||||||
httputil.MakeAuthAPI("send_to_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("send_to_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
@ -1118,6 +1299,7 @@ func Setup(
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
|
// TODO: update for cryptoIDs
|
||||||
v3mux.Handle("/rooms/{roomID}/read_markers",
|
v3mux.Handle("/rooms/{roomID}/read_markers",
|
||||||
httputil.MakeAuthAPI("rooms_read_markers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("rooms_read_markers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
if r := rateLimits.Limit(req, device); r != nil {
|
if r := rateLimits.Limit(req, device); r != nil {
|
||||||
|
@ -1144,6 +1326,7 @@ func Setup(
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
|
// TODO: update for cryptoIDs
|
||||||
v3mux.Handle("/rooms/{roomID}/upgrade",
|
v3mux.Handle("/rooms/{roomID}/upgrade",
|
||||||
httputil.MakeAuthAPI("rooms_upgrade", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("rooms_upgrade", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
@ -1472,6 +1655,12 @@ func Setup(
|
||||||
return UploadKeys(req, userAPI, device)
|
return UploadKeys(req, userAPI, device)
|
||||||
}, httputil.WithAllowGuests()),
|
}, httputil.WithAllowGuests()),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
unstableMux.Handle("/org.matrix.msc4080/keys/upload",
|
||||||
|
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
logrus.Info("Processing request to /org.matrix.msc4080/keys/upload")
|
||||||
|
return UploadKeysCryptoIDs(req, userAPI, device)
|
||||||
|
}, httputil.WithAllowGuests()),
|
||||||
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
v3mux.Handle("/keys/query",
|
v3mux.Handle("/keys/query",
|
||||||
httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
return QueryKeys(req, userAPI, device)
|
return QueryKeys(req, userAPI, device)
|
||||||
|
@ -1495,6 +1684,7 @@ func Setup(
|
||||||
return SetReceipt(req, userAPI, syncProducer, device, vars["roomId"], vars["receiptType"], vars["eventId"])
|
return SetReceipt(req, userAPI, syncProducer, device, vars["roomId"], vars["receiptType"], vars["eventId"])
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
// TODO: update for cryptoIDs
|
||||||
v3mux.Handle("/presence/{userId}/status",
|
v3mux.Handle("/presence/{userId}/status",
|
||||||
httputil.MakeAuthAPI("set_presence", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("set_presence", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
|
272
clientapi/routing/send_pdus.go
Normal file
272
clientapi/routing/send_pdus.go
Normal file
|
@ -0,0 +1,272 @@
|
||||||
|
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||||
|
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PDUInfo struct {
|
||||||
|
Version string `json:"room_version"`
|
||||||
|
ViaServer string `json:"via_server,omitempty"`
|
||||||
|
PDU json.RawMessage `json:"pdu"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type sendPDUsRequest struct {
|
||||||
|
PDUs []PDUInfo `json:"pdus"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendPDUs implements /sendPDUs
|
||||||
|
// nolint:gocyclo
|
||||||
|
func SendPDUs(
|
||||||
|
req *http.Request, device *api.Device,
|
||||||
|
cfg *config.ClientAPI,
|
||||||
|
profileAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
|
||||||
|
asAPI appserviceAPI.AppServiceInternalAPI,
|
||||||
|
txnID *string,
|
||||||
|
) util.JSONResponse {
|
||||||
|
// TODO: cryptoIDs - should this include an "eventType"?
|
||||||
|
// if it's a bulk send endpoint, I don't think that makes any sense since there are multiple event types
|
||||||
|
// In that case, how do I know how to treat the events?
|
||||||
|
// I could sort them all by roomID?
|
||||||
|
// Then filter them down based on event type? (how do I collect groups of events such as for room creation?)
|
||||||
|
// Possibly based on event hash tracking that I know were sent to the client?
|
||||||
|
// For createRoom, I know what the possible list of events are, so try to find those and collect them to send to room creation.
|
||||||
|
// Could also sort by depth... but that seems dangerous and depth may not be a field forever
|
||||||
|
// Does it matter at all?
|
||||||
|
// Can't I just forward all the events to the roomserver?
|
||||||
|
// Do I need to do any specific processing on them?
|
||||||
|
|
||||||
|
var pdus sendPDUsRequest
|
||||||
|
resErr := httputil.UnmarshalJSONRequest(req, &pdus)
|
||||||
|
if resErr != nil {
|
||||||
|
return *resErr
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := spec.NewUserID(device.UserID, true)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.InvalidParam(err.Error()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a mutex for the specific user in the specific room
|
||||||
|
// this avoids a situation where events that are received in quick succession are sent to the roomserver in a jumbled order
|
||||||
|
// TODO: cryptoIDs - where to get roomID from?
|
||||||
|
mutex, _ := userRoomSendMutexes.LoadOrStore("roomID"+userID.String(), &sync.Mutex{})
|
||||||
|
mutex.(*sync.Mutex).Lock()
|
||||||
|
defer mutex.(*sync.Mutex).Unlock()
|
||||||
|
|
||||||
|
inputs := make([]roomserverAPI.InputRoomEvent, 0, len(pdus.PDUs))
|
||||||
|
for _, event := range pdus.PDUs {
|
||||||
|
// TODO: cryptoIDs - event hash check?
|
||||||
|
verImpl, err := gomatrixserverlib.GetRoomVersion(gomatrixserverlib.RoomVersion(event.Version))
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: spec.InternalServerError{Err: err.Error()},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//eventJSON, err := json.Marshal(event)
|
||||||
|
//if err != nil {
|
||||||
|
// return util.JSONResponse{
|
||||||
|
// Code: http.StatusInternalServerError,
|
||||||
|
// JSON: spec.InternalServerError{Err: err.Error()},
|
||||||
|
// }
|
||||||
|
//}
|
||||||
|
// TODO: cryptoIDs - how should we be converting to a PDU here?
|
||||||
|
// if the hash matches an event we sent to the client, then the JSON should be good.
|
||||||
|
// But how do we know how to fill out if the event is redacted if we use the trustedJSON function?
|
||||||
|
// Also - untrusted JSON seems better - except it strips off the unsigned field?
|
||||||
|
// Also - gmsl events don't store the `hashes` field... problem?
|
||||||
|
|
||||||
|
pdu, err := verImpl.NewEventFromUntrustedJSON(event.PDU)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: spec.InternalServerError{Err: err.Error()},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
util.GetLogger(req.Context()).Infof("Processing %s event (%s): %s", pdu.Type(), pdu.EventID(), pdu.JSON())
|
||||||
|
|
||||||
|
// Check that the event is signed by the server sending the request.
|
||||||
|
redacted, err := verImpl.RedactEventJSON(pdu.JSON())
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("RedactEventJSON failed")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
verifier := gomatrixserverlib.JSONVerifierSelf{}
|
||||||
|
verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{
|
||||||
|
ServerName: spec.ServerName(pdu.SenderID()),
|
||||||
|
Message: redacted,
|
||||||
|
AtTS: pdu.OriginServerTS(),
|
||||||
|
ValidityCheckingFunc: gomatrixserverlib.StrictValiditySignatureCheck,
|
||||||
|
}}
|
||||||
|
verifyResults, err := verifier.VerifyJSONs(req.Context(), verifyRequests)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("keys.VerifyJSONs failed")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if verifyResults[0].Error != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(verifyResults[0].Error).Error("Signature check failed: ")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch pdu.Type() {
|
||||||
|
case spec.MRoomCreate:
|
||||||
|
case spec.MRoomMember:
|
||||||
|
var membership gomatrixserverlib.MemberContent
|
||||||
|
err = json.Unmarshal(pdu.Content(), &membership)
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
util.GetLogger(req.Context()).Errorf("m.room.member event (%s) content invalid: %v", pdu.EventID(), pdu.Content())
|
||||||
|
continue
|
||||||
|
case membership.Membership == spec.Join:
|
||||||
|
deviceUserID, err := spec.NewUserID(device.UserID, true)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !cfg.Matrix.IsLocalServerName(pdu.RoomID().Domain()) {
|
||||||
|
queryReq := roomserverAPI.QueryMembershipForUserRequest{
|
||||||
|
RoomID: pdu.RoomID().String(),
|
||||||
|
UserID: *deviceUserID,
|
||||||
|
}
|
||||||
|
var queryRes roomserverAPI.QueryMembershipForUserResponse
|
||||||
|
if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed")
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: spec.InternalServerError{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !queryRes.IsInRoom {
|
||||||
|
// This is a join event to a remote room
|
||||||
|
// TODO: cryptoIDs - figure out how to obtain unsigned contents for outstanding federated invites
|
||||||
|
joinReq := roomserverAPI.PerformJoinRequestCryptoIDs{
|
||||||
|
RoomID: pdu.RoomID().String(),
|
||||||
|
UserID: device.UserID,
|
||||||
|
IsGuest: device.AccountType == api.AccountTypeGuest,
|
||||||
|
ServerNames: []spec.ServerName{spec.ServerName(event.ViaServer)},
|
||||||
|
JoinEvent: pdu,
|
||||||
|
}
|
||||||
|
err := rsAPI.PerformSendJoinCryptoIDs(req.Context(), &joinReq)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(req.Context()).Errorf("Failed processing %s event (%s): %v", pdu.Type(), pdu.EventID(), err)
|
||||||
|
}
|
||||||
|
continue // NOTE: don't send this event on to the roomserver
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case membership.Membership == spec.Invite:
|
||||||
|
stateKey := pdu.StateKey()
|
||||||
|
if stateKey == nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
JSON: spec.Forbidden("invalid state_key for membership event"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
invitedUserID, err := rsAPI.QueryUserIDForSender(req.Context(), pdu.RoomID(), spec.SenderID(*stateKey))
|
||||||
|
if err != nil || invitedUserID == nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusNotFound,
|
||||||
|
JSON: spec.NotFound("cannot find userID for invite event"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !cfg.Matrix.IsLocalServerName(spec.ServerName(invitedUserID.Domain())) {
|
||||||
|
inviteReq := roomserverAPI.PerformInviteRequestCryptoIDs{
|
||||||
|
RoomID: pdu.RoomID().String(),
|
||||||
|
UserID: *invitedUserID,
|
||||||
|
InviteEvent: pdu,
|
||||||
|
}
|
||||||
|
err := rsAPI.PerformSendInviteCryptoIDs(req.Context(), &inviteReq)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(req.Context()).Errorf("Failed processing %s event (%s): %v", pdu.Type(), pdu.EventID(), err)
|
||||||
|
}
|
||||||
|
continue // NOTE: don't send this event on to the roomserver
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: cryptoIDs - does it matter which order these are added?
|
||||||
|
// yes - if the events are for room creation.
|
||||||
|
// Could make this a client requirement? ie. events are processed based on the order they appear
|
||||||
|
// We need to check event validity after processing each event.
|
||||||
|
// ie. what if the client changes power levels that disallow further events they sent?
|
||||||
|
// We should be doing this already as part of `SendInputRoomEvents`, but how should we pass this
|
||||||
|
// failure back to the client?
|
||||||
|
|
||||||
|
var transactionID *roomserverAPI.TransactionID
|
||||||
|
if txnID != nil {
|
||||||
|
transactionID = &roomserverAPI.TransactionID{
|
||||||
|
SessionID: device.SessionID, TransactionID: *txnID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inputs = append(inputs, roomserverAPI.InputRoomEvent{
|
||||||
|
Kind: roomserverAPI.KindNew,
|
||||||
|
Event: &types.HeaderedEvent{PDU: pdu},
|
||||||
|
Origin: userID.Domain(),
|
||||||
|
// TODO: cryptoIDs - what to do with this field?
|
||||||
|
// should probably generate this based on the event type being sent?
|
||||||
|
//SendAsServer: roomserverAPI.DoNotSendToOtherServers,
|
||||||
|
TransactionID: transactionID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
startedSubmittingEvents := time.Now()
|
||||||
|
// send the events to the roomserver
|
||||||
|
if err := roomserverAPI.SendInputRoomEvents(req.Context(), rsAPI, userID.Domain(), inputs, false); err != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed")
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: spec.InternalServerError{Err: err.Error()},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
timeToSubmitEvents := time.Since(startedSubmittingEvents)
|
||||||
|
sendEventDuration.With(prometheus.Labels{"action": "submit"}).Observe(float64(timeToSubmitEvents.Milliseconds()))
|
||||||
|
|
||||||
|
// Add response to transactionsCache
|
||||||
|
if txnID != nil {
|
||||||
|
// TODO : cryptoIDs - fix this
|
||||||
|
//res := util.JSONResponse{
|
||||||
|
// Code: http.StatusOK,
|
||||||
|
// JSON: sendEventResponse{e.EventID()},
|
||||||
|
//}
|
||||||
|
//txnCache.AddTransaction(device.AccessToken, *txnID, req.URL, &res)
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusOK,
|
||||||
|
}
|
||||||
|
}
|
|
@ -32,6 +32,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/syncapi/synctypes"
|
"github.com/matrix-org/dendrite/syncapi/synctypes"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib/fclient"
|
||||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
@ -44,6 +45,11 @@ type sendEventResponse struct {
|
||||||
EventID string `json:"event_id"`
|
EventID string `json:"event_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type sendEventResponseCryptoIDs struct {
|
||||||
|
EventID string `json:"event_id"`
|
||||||
|
PDU json.RawMessage `json:"pdu"`
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
userRoomSendMutexes sync.Map // (roomID+userID) -> mutex. mutexes to ensure correct ordering of sendEvents
|
userRoomSendMutexes sync.Map // (roomID+userID) -> mutex. mutexes to ensure correct ordering of sendEvents
|
||||||
)
|
)
|
||||||
|
@ -94,7 +100,7 @@ func SendEvent(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Translate user ID state keys to room keys in pseudo ID rooms
|
// Translate user ID state keys to room keys in pseudo ID rooms
|
||||||
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && stateKey != nil {
|
if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && stateKey != nil {
|
||||||
parsedRoomID, innerErr := spec.NewRoomID(roomID)
|
parsedRoomID, innerErr := spec.NewRoomID(roomID)
|
||||||
if innerErr != nil {
|
if innerErr != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
@ -148,7 +154,7 @@ func SendEvent(
|
||||||
}
|
}
|
||||||
|
|
||||||
// for power level events we need to replace the userID with the pseudoID
|
// for power level events we need to replace the userID with the pseudoID
|
||||||
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && eventType == spec.MRoomPowerLevels {
|
if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && eventType == spec.MRoomPowerLevels {
|
||||||
err = updatePowerLevels(req, r, roomID, rsAPI)
|
err = updatePowerLevels(req, r, roomID, rsAPI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
@ -166,7 +172,7 @@ func SendEvent(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
e, resErr := generateSendEvent(req.Context(), r, device, roomID, eventType, stateKey, rsAPI, evTime)
|
e, resErr := generateSendEvent(req.Context(), r, device, roomID, eventType, stateKey, rsAPI, evTime, false)
|
||||||
if resErr != nil {
|
if resErr != nil {
|
||||||
return *resErr
|
return *resErr
|
||||||
}
|
}
|
||||||
|
@ -262,6 +268,156 @@ func SendEvent(
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SendEventCryptoIDs implements:
|
||||||
|
//
|
||||||
|
// /rooms/{roomID}/send/{eventType}
|
||||||
|
// /rooms/{roomID}/send/{eventType}/{txnID}
|
||||||
|
// /rooms/{roomID}/state/{eventType}/{stateKey}
|
||||||
|
//
|
||||||
|
// nolint: gocyclo
|
||||||
|
func SendEventCryptoIDs(
|
||||||
|
req *http.Request,
|
||||||
|
device *userapi.Device,
|
||||||
|
roomID, eventType string, txnID, stateKey *string,
|
||||||
|
cfg *config.ClientAPI,
|
||||||
|
rsAPI api.ClientRoomserverAPI,
|
||||||
|
txnCache *transactions.Cache,
|
||||||
|
) util.JSONResponse {
|
||||||
|
roomVersion, err := rsAPI.QueryRoomVersionForRoom(req.Context(), roomID)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.UnsupportedRoomVersion(err.Error()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if txnID != nil {
|
||||||
|
// Try to fetch response from transactionsCache
|
||||||
|
if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID, req.URL); ok {
|
||||||
|
return *res
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Translate user ID state keys to room keys in pseudo ID rooms
|
||||||
|
if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && stateKey != nil {
|
||||||
|
parsedRoomID, innerErr := spec.NewRoomID(roomID)
|
||||||
|
if innerErr != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.InvalidParam("invalid room ID"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newStateKey, innerErr := synctypes.FromClientStateKey(*parsedRoomID, *stateKey, func(roomID spec.RoomID, userID spec.UserID) (*spec.SenderID, error) {
|
||||||
|
return rsAPI.QuerySenderIDForUser(req.Context(), roomID, userID)
|
||||||
|
})
|
||||||
|
if innerErr != nil {
|
||||||
|
// TODO: work out better logic for failure cases (e.g. sender ID not found)
|
||||||
|
util.GetLogger(req.Context()).WithError(innerErr).Error("synctypes.FromClientStateKey failed")
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: spec.Unknown("internal server error"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stateKey = newStateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
var r map[string]interface{} // must be a JSON object
|
||||||
|
resErr := httputil.UnmarshalJSONRequest(req, &r)
|
||||||
|
if resErr != nil {
|
||||||
|
return *resErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if stateKey != nil {
|
||||||
|
// If the existing/new state content are equal, return the existing event_id, making the request idempotent.
|
||||||
|
if resp := stateEqual(req.Context(), rsAPI, eventType, *stateKey, roomID, r); resp != nil {
|
||||||
|
return *resp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
startedGeneratingEvent := time.Now()
|
||||||
|
|
||||||
|
// If we're sending a membership update, make sure to strip the authorised
|
||||||
|
// via key if it is present, otherwise other servers won't be able to auth
|
||||||
|
// the event if the room is set to the "restricted" join rule.
|
||||||
|
if eventType == spec.MRoomMember {
|
||||||
|
delete(r, "join_authorised_via_users_server")
|
||||||
|
}
|
||||||
|
|
||||||
|
// for power level events we need to replace the userID with the pseudoID
|
||||||
|
if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && eventType == spec.MRoomPowerLevels {
|
||||||
|
err = updatePowerLevels(req, r, roomID, rsAPI)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: spec.InternalServerError{Err: err.Error()},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
evTime, err := httputil.ParseTSParam(req)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.InvalidParam(err.Error()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
e, resErr := generateSendEvent(req.Context(), r, device, roomID, eventType, stateKey, rsAPI, evTime, true)
|
||||||
|
if resErr != nil {
|
||||||
|
return *resErr
|
||||||
|
}
|
||||||
|
timeToGenerateEvent := time.Since(startedGeneratingEvent)
|
||||||
|
|
||||||
|
// validate that the aliases exists
|
||||||
|
if eventType == spec.MRoomCanonicalAlias && stateKey != nil && *stateKey == "" {
|
||||||
|
aliasReq := api.AliasEvent{}
|
||||||
|
if err = json.Unmarshal(e.Content(), &aliasReq); err != nil {
|
||||||
|
return util.ErrorResponse(fmt.Errorf("unable to parse alias event: %w", err))
|
||||||
|
}
|
||||||
|
if !aliasReq.Valid() {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.InvalidParam("Request contains invalid aliases."),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
aliasRes := &api.GetAliasesForRoomIDResponse{}
|
||||||
|
if err = rsAPI.GetAliasesForRoomID(req.Context(), &api.GetAliasesForRoomIDRequest{RoomID: roomID}, aliasRes); err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: spec.InternalServerError{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var found int
|
||||||
|
requestAliases := append(aliasReq.AltAliases, aliasReq.Alias)
|
||||||
|
for _, alias := range aliasRes.Aliases {
|
||||||
|
for _, altAlias := range requestAliases {
|
||||||
|
if altAlias == alias {
|
||||||
|
found++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// check that we found at least the same amount of existing aliases as are in the request
|
||||||
|
if aliasReq.Alias != "" && found < len(requestAliases) {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.BadAlias("No matching alias found."),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sendEventDuration.With(prometheus.Labels{"action": "build"}).Observe(float64(timeToGenerateEvent.Milliseconds()))
|
||||||
|
|
||||||
|
res := util.JSONResponse{
|
||||||
|
Code: http.StatusOK,
|
||||||
|
JSON: sendEventResponseCryptoIDs{
|
||||||
|
EventID: e.EventID(),
|
||||||
|
PDU: json.RawMessage(e.JSON()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
func updatePowerLevels(req *http.Request, r map[string]interface{}, roomID string, rsAPI api.ClientRoomserverAPI) error {
|
func updatePowerLevels(req *http.Request, r map[string]interface{}, roomID string, rsAPI api.ClientRoomserverAPI) error {
|
||||||
users, ok := r["users"]
|
users, ok := r["users"]
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -329,6 +485,7 @@ func generateSendEvent(
|
||||||
roomID, eventType string, stateKey *string,
|
roomID, eventType string, stateKey *string,
|
||||||
rsAPI api.ClientRoomserverAPI,
|
rsAPI api.ClientRoomserverAPI,
|
||||||
evTime time.Time,
|
evTime time.Time,
|
||||||
|
cryptoIDs bool,
|
||||||
) (gomatrixserverlib.PDU, *util.JSONResponse) {
|
) (gomatrixserverlib.PDU, *util.JSONResponse) {
|
||||||
// parse the incoming http request
|
// parse the incoming http request
|
||||||
fullUserID, err := spec.NewUserID(device.UserID, true)
|
fullUserID, err := spec.NewUserID(device.UserID, true)
|
||||||
|
@ -376,13 +533,19 @@ func generateSendEvent(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
identity, err := rsAPI.SigningIdentityFor(ctx, *validRoomID, *fullUserID)
|
var identity fclient.SigningIdentity
|
||||||
if err != nil {
|
if !cryptoIDs {
|
||||||
|
id, idErr := rsAPI.SigningIdentityFor(ctx, *validRoomID, *fullUserID)
|
||||||
|
if idErr != nil {
|
||||||
return nil, &util.JSONResponse{
|
return nil, &util.JSONResponse{
|
||||||
Code: http.StatusInternalServerError,
|
Code: http.StatusInternalServerError,
|
||||||
JSON: spec.InternalServerError{},
|
JSON: spec.InternalServerError{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
identity = id
|
||||||
|
} else {
|
||||||
|
identity.ServerName = spec.ServerName(*senderID)
|
||||||
|
}
|
||||||
|
|
||||||
var queryRes api.QueryLatestEventsAndStateResponse
|
var queryRes api.QueryLatestEventsAndStateResponse
|
||||||
e, err := eventutil.QueryAndBuildEvent(ctx, &proto, &identity, evTime, rsAPI, &queryRes)
|
e, err := eventutil.QueryAndBuildEvent(ctx, &proto, &identity, evTime, rsAPI, &queryRes)
|
||||||
|
|
|
@ -169,7 +169,7 @@ func SendServerNotice(
|
||||||
PowerLevelContentOverride: pl,
|
PowerLevelContentOverride: pl,
|
||||||
}
|
}
|
||||||
|
|
||||||
roomRes := createRoom(ctx, crReq, senderDevice, cfgClient, userAPI, rsAPI, asAPI, time.Now())
|
roomRes := createRoom(ctx, crReq, senderDevice, cfgClient, userAPI, rsAPI, asAPI, time.Now(), false)
|
||||||
|
|
||||||
switch data := roomRes.JSON.(type) {
|
switch data := roomRes.JSON.(type) {
|
||||||
case createRoomResponse:
|
case createRoomResponse:
|
||||||
|
@ -215,7 +215,7 @@ func SendServerNotice(
|
||||||
}
|
}
|
||||||
if !membershipRes.IsInRoom {
|
if !membershipRes.IsInRoom {
|
||||||
// re-invite the user
|
// re-invite the user
|
||||||
res, err := sendInvite(ctx, userAPI, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now())
|
res, err := sendInvite(ctx, userAPI, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now(), false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
@ -228,7 +228,7 @@ func SendServerNotice(
|
||||||
"body": r.Content.Body,
|
"body": r.Content.Body,
|
||||||
"msgtype": r.Content.MsgType,
|
"msgtype": r.Content.MsgType,
|
||||||
}
|
}
|
||||||
e, resErr := generateSendEvent(ctx, request, senderDevice, roomID, "m.room.message", nil, rsAPI, time.Now())
|
e, resErr := generateSendEvent(ctx, request, senderDevice, roomID, "m.room.message", nil, rsAPI, time.Now(), false)
|
||||||
if resErr != nil {
|
if resErr != nil {
|
||||||
logrus.Errorf("failed to send message: %+v", resErr)
|
logrus.Errorf("failed to send message: %+v", resErr)
|
||||||
return *resErr
|
return *resErr
|
||||||
|
|
|
@ -198,6 +198,7 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
|
||||||
// state to see if there is an event with that type and state key, if there
|
// state to see if there is an event with that type and state key, if there
|
||||||
// is then (by default) we return the content, otherwise a 404.
|
// is then (by default) we return the content, otherwise a 404.
|
||||||
// If eventFormat=true, sends the whole event else just the content.
|
// If eventFormat=true, sends the whole event else just the content.
|
||||||
|
// nolint:gocyclo
|
||||||
func OnIncomingStateTypeRequest(
|
func OnIncomingStateTypeRequest(
|
||||||
ctx context.Context, device *userapi.Device, rsAPI api.ClientRoomserverAPI,
|
ctx context.Context, device *userapi.Device, rsAPI api.ClientRoomserverAPI,
|
||||||
roomID, evType, stateKey string, eventFormat bool,
|
roomID, evType, stateKey string, eventFormat bool,
|
||||||
|
@ -214,7 +215,7 @@ func OnIncomingStateTypeRequest(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Translate user ID state keys to room keys in pseudo ID rooms
|
// Translate user ID state keys to room keys in pseudo ID rooms
|
||||||
if roomVer == gomatrixserverlib.RoomVersionPseudoIDs {
|
if roomVer == gomatrixserverlib.RoomVersionPseudoIDs || roomVer == gomatrixserverlib.RoomVersionCryptoIDs {
|
||||||
parsedRoomID, err := spec.NewRoomID(roomID)
|
parsedRoomID, err := spec.NewRoomID(roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
|
|
@ -58,12 +58,18 @@ type RoomserverFederationAPI interface {
|
||||||
PerformDirectoryLookup(ctx context.Context, request *PerformDirectoryLookupRequest, response *PerformDirectoryLookupResponse) error
|
PerformDirectoryLookup(ctx context.Context, request *PerformDirectoryLookupRequest, response *PerformDirectoryLookupResponse) error
|
||||||
// Handle an instruction to make_join & send_join with a remote server.
|
// Handle an instruction to make_join & send_join with a remote server.
|
||||||
PerformJoin(ctx context.Context, request *PerformJoinRequest, response *PerformJoinResponse)
|
PerformJoin(ctx context.Context, request *PerformJoinRequest, response *PerformJoinResponse)
|
||||||
|
PerformMakeJoin(ctx context.Context, request *PerformJoinRequest) (gomatrixserverlib.PDU, gomatrixserverlib.RoomVersion, spec.ServerName, error)
|
||||||
|
PerformSendJoin(ctx context.Context, request *PerformSendJoinRequestCryptoIDs, response *PerformJoinResponse)
|
||||||
// Handle an instruction to make_leave & send_leave with a remote server.
|
// Handle an instruction to make_leave & send_leave with a remote server.
|
||||||
PerformLeave(ctx context.Context, request *PerformLeaveRequest, response *PerformLeaveResponse) error
|
PerformLeave(ctx context.Context, request *PerformLeaveRequest, response *PerformLeaveResponse, cryptoIDs bool) error
|
||||||
// Handle sending an invite to a remote server.
|
// Handle sending an invite to a remote server.
|
||||||
SendInvite(ctx context.Context, event gomatrixserverlib.PDU, strippedState []gomatrixserverlib.InviteStrippedState) (gomatrixserverlib.PDU, error)
|
SendInvite(ctx context.Context, event gomatrixserverlib.PDU, strippedState []gomatrixserverlib.InviteStrippedState) (gomatrixserverlib.PDU, error)
|
||||||
// Handle sending an invite to a remote server.
|
// Handle sending an invite to a remote server.
|
||||||
SendInviteV3(ctx context.Context, event gomatrixserverlib.ProtoEvent, invitee spec.UserID, version gomatrixserverlib.RoomVersion, strippedState []gomatrixserverlib.InviteStrippedState) (gomatrixserverlib.PDU, error)
|
SendInviteV3(ctx context.Context, event gomatrixserverlib.ProtoEvent, invitee spec.UserID, version gomatrixserverlib.RoomVersion, strippedState []gomatrixserverlib.InviteStrippedState) (gomatrixserverlib.PDU, error)
|
||||||
|
// Handle sending an invite to a remote server.
|
||||||
|
MakeInviteCryptoIDs(ctx context.Context, event gomatrixserverlib.ProtoEvent, invitee spec.UserID, version gomatrixserverlib.RoomVersion, strippedState []gomatrixserverlib.InviteStrippedState) (gomatrixserverlib.PDU, error)
|
||||||
|
// Handle sending an invite to a remote server.
|
||||||
|
SendInviteCryptoIDs(ctx context.Context, event gomatrixserverlib.PDU, invitee spec.UserID, version gomatrixserverlib.RoomVersion) error
|
||||||
// Handle an instruction to peek a room on a remote server.
|
// Handle an instruction to peek a room on a remote server.
|
||||||
PerformOutboundPeek(ctx context.Context, request *PerformOutboundPeekRequest, response *PerformOutboundPeekResponse) error
|
PerformOutboundPeek(ctx context.Context, request *PerformOutboundPeekRequest, response *PerformOutboundPeekResponse) error
|
||||||
// Query the server names of the joined hosts in a room.
|
// Query the server names of the joined hosts in a room.
|
||||||
|
@ -168,6 +174,15 @@ type PerformJoinRequest struct {
|
||||||
Unsigned map[string]interface{} `json:"unsigned"`
|
Unsigned map[string]interface{} `json:"unsigned"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PerformSendJoinRequestCryptoIDs struct {
|
||||||
|
RoomID string `json:"room_id"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
// The sorted list of servers to try. Servers will be tried sequentially, after de-duplication.
|
||||||
|
ServerNames types.ServerNames `json:"server_names"`
|
||||||
|
Unsigned map[string]interface{} `json:"unsigned"`
|
||||||
|
Event gomatrixserverlib.PDU
|
||||||
|
}
|
||||||
|
|
||||||
type PerformJoinResponse struct {
|
type PerformJoinResponse struct {
|
||||||
JoinedVia spec.ServerName
|
JoinedVia spec.ServerName
|
||||||
LastError *gomatrix.HTTPError
|
LastError *gomatrix.HTTPError
|
||||||
|
|
|
@ -239,6 +239,319 @@ func (r *FederationInternalAPI) performJoinUsingServer(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PerformMakeJoin implements api.FederationInternalAPI
|
||||||
|
func (r *FederationInternalAPI) PerformMakeJoin(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.PerformJoinRequest,
|
||||||
|
) (gomatrixserverlib.PDU, gomatrixserverlib.RoomVersion, spec.ServerName, error) {
|
||||||
|
// Check that a join isn't already in progress for this user/room.
|
||||||
|
j := federatedJoin{request.UserID, request.RoomID}
|
||||||
|
if _, found := r.joins.Load(j); found {
|
||||||
|
return nil, "", "", &gomatrix.HTTPError{
|
||||||
|
Code: 429,
|
||||||
|
Message: `{
|
||||||
|
"errcode": "M_LIMIT_EXCEEDED",
|
||||||
|
"error": "There is already a federated join to this room in progress. Please wait for it to finish."
|
||||||
|
}`, // TODO: Why do none of our error types play nicely with each other?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.joins.Store(j, nil)
|
||||||
|
defer r.joins.Delete(j)
|
||||||
|
|
||||||
|
// Deduplicate the server names we were provided but keep the ordering
|
||||||
|
// as this encodes useful information about which servers are most likely
|
||||||
|
// to respond.
|
||||||
|
seenSet := make(map[spec.ServerName]bool)
|
||||||
|
var uniqueList []spec.ServerName
|
||||||
|
for _, srv := range request.ServerNames {
|
||||||
|
if seenSet[srv] || r.cfg.Matrix.IsLocalServerName(srv) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seenSet[srv] = true
|
||||||
|
uniqueList = append(uniqueList, srv)
|
||||||
|
}
|
||||||
|
request.ServerNames = uniqueList
|
||||||
|
|
||||||
|
// Try each server that we were provided until we land on one that
|
||||||
|
// successfully completes the make-join send-join dance.
|
||||||
|
var lastErr error
|
||||||
|
for _, serverName := range request.ServerNames {
|
||||||
|
var joinEvent gomatrixserverlib.PDU
|
||||||
|
var roomVersion gomatrixserverlib.RoomVersion
|
||||||
|
var err error
|
||||||
|
if joinEvent, roomVersion, _, err = r.performMakeJoinUsingServer(
|
||||||
|
ctx,
|
||||||
|
request.RoomID,
|
||||||
|
request.UserID,
|
||||||
|
request.Content,
|
||||||
|
serverName,
|
||||||
|
); err != nil {
|
||||||
|
logrus.WithError(err).WithFields(logrus.Fields{
|
||||||
|
"server_name": serverName,
|
||||||
|
"room_id": request.RoomID,
|
||||||
|
}).Warnf("Failed to join room through server")
|
||||||
|
lastErr = err
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// We're all good.
|
||||||
|
return joinEvent, roomVersion, serverName, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we reach here then we didn't complete a join for some reason.
|
||||||
|
var httpErr gomatrix.HTTPError
|
||||||
|
var lastError *gomatrix.HTTPError
|
||||||
|
if ok := errors.As(lastErr, &httpErr); ok {
|
||||||
|
httpErr.Message = string(httpErr.Contents)
|
||||||
|
lastError = &httpErr
|
||||||
|
} else {
|
||||||
|
lastError = &gomatrix.HTTPError{
|
||||||
|
Code: 0,
|
||||||
|
WrappedError: nil,
|
||||||
|
Message: "Unknown HTTP error",
|
||||||
|
}
|
||||||
|
if lastErr != nil {
|
||||||
|
lastError.Message = lastErr.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.Errorf(
|
||||||
|
"failed to join user %q to room %q through %d server(s): last error %s",
|
||||||
|
request.UserID, request.RoomID, len(request.ServerNames), lastError,
|
||||||
|
)
|
||||||
|
return nil, "", "", lastError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *FederationInternalAPI) performMakeJoinUsingServer(
|
||||||
|
ctx context.Context,
|
||||||
|
roomID, userID string,
|
||||||
|
content map[string]interface{},
|
||||||
|
serverName spec.ServerName,
|
||||||
|
) (gomatrixserverlib.PDU, gomatrixserverlib.RoomVersion, spec.SenderID, error) {
|
||||||
|
if !r.shouldAttemptDirectFederation(serverName) {
|
||||||
|
return nil, "", "", fmt.Errorf("relay servers have no meaningful response for join.")
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := spec.NewUserID(userID, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", "", err
|
||||||
|
}
|
||||||
|
room, err := spec.NewRoomID(roomID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
joinInput := gomatrixserverlib.PerformMakeJoinInput{
|
||||||
|
UserID: user,
|
||||||
|
RoomID: room,
|
||||||
|
ServerName: serverName,
|
||||||
|
Content: content,
|
||||||
|
PrivateKey: r.cfg.Matrix.PrivateKey,
|
||||||
|
KeyID: r.cfg.Matrix.KeyID,
|
||||||
|
KeyRing: r.keyRing,
|
||||||
|
GetOrCreateSenderID: func(ctx context.Context, userID spec.UserID, roomID spec.RoomID, roomVersion string) (spec.SenderID, ed25519.PrivateKey, error) {
|
||||||
|
// assign a roomNID, otherwise we can't create a private key for the user
|
||||||
|
_, nidErr := r.rsAPI.AssignRoomNID(ctx, roomID, gomatrixserverlib.RoomVersion(roomVersion))
|
||||||
|
if nidErr != nil {
|
||||||
|
return "", nil, nidErr
|
||||||
|
}
|
||||||
|
key, keyErr := r.rsAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID)
|
||||||
|
if keyErr != nil {
|
||||||
|
return "", nil, keyErr
|
||||||
|
}
|
||||||
|
return spec.SenderIDFromPseudoIDKey(key), key, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
joinEvent, version, senderID, joinErr := gomatrixserverlib.PerformMakeJoin(ctx, r, joinInput)
|
||||||
|
|
||||||
|
if joinErr != nil {
|
||||||
|
if !joinErr.Reachable {
|
||||||
|
r.statistics.ForServer(joinErr.ServerName).Failure()
|
||||||
|
} else {
|
||||||
|
r.statistics.ForServer(joinErr.ServerName).Success(statistics.SendDirect)
|
||||||
|
}
|
||||||
|
return nil, "", "", joinErr.Err
|
||||||
|
}
|
||||||
|
r.statistics.ForServer(serverName).Success(statistics.SendDirect)
|
||||||
|
if joinEvent == nil {
|
||||||
|
return nil, "", "", fmt.Errorf("Received nil joinEvent response from gomatrixserverlib.PerformJoin")
|
||||||
|
}
|
||||||
|
|
||||||
|
return joinEvent, version, senderID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PerformSendJoin implements api.FederationInternalAPI
|
||||||
|
func (r *FederationInternalAPI) PerformSendJoin(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.PerformSendJoinRequestCryptoIDs,
|
||||||
|
response *api.PerformJoinResponse,
|
||||||
|
) {
|
||||||
|
// Check that a join isn't already in progress for this user/room.
|
||||||
|
j := federatedJoin{request.UserID, request.RoomID}
|
||||||
|
if _, found := r.joins.Load(j); found {
|
||||||
|
response.LastError = &gomatrix.HTTPError{
|
||||||
|
Code: 429,
|
||||||
|
Message: `{
|
||||||
|
"errcode": "M_LIMIT_EXCEEDED",
|
||||||
|
"error": "There is already a federated join to this room in progress. Please wait for it to finish."
|
||||||
|
}`,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.joins.Store(j, nil)
|
||||||
|
defer r.joins.Delete(j)
|
||||||
|
|
||||||
|
// Deduplicate the server names we were provided but keep the ordering
|
||||||
|
// as this encodes useful information about which servers are most likely
|
||||||
|
// to respond.
|
||||||
|
seenSet := make(map[spec.ServerName]bool)
|
||||||
|
var uniqueList []spec.ServerName
|
||||||
|
for _, srv := range request.ServerNames {
|
||||||
|
if seenSet[srv] || r.cfg.Matrix.IsLocalServerName(srv) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seenSet[srv] = true
|
||||||
|
uniqueList = append(uniqueList, srv)
|
||||||
|
}
|
||||||
|
request.ServerNames = uniqueList
|
||||||
|
|
||||||
|
// Try each server that we were provided until we land on one that
|
||||||
|
// successfully completes the make-join send-join dance.
|
||||||
|
var lastErr error
|
||||||
|
for _, serverName := range request.ServerNames {
|
||||||
|
if err := r.performSendJoinUsingServer(
|
||||||
|
ctx,
|
||||||
|
request.RoomID,
|
||||||
|
request.UserID,
|
||||||
|
request.Unsigned,
|
||||||
|
request.Event,
|
||||||
|
serverName,
|
||||||
|
); err != nil {
|
||||||
|
logrus.WithError(err).WithFields(logrus.Fields{
|
||||||
|
"server_name": serverName,
|
||||||
|
"room_id": request.RoomID,
|
||||||
|
}).Warnf("Failed to join room through server")
|
||||||
|
lastErr = err
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// We're all good.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we reach here then we didn't complete a join for some reason.
|
||||||
|
var httpErr gomatrix.HTTPError
|
||||||
|
var lastError *gomatrix.HTTPError
|
||||||
|
if ok := errors.As(lastErr, &httpErr); ok {
|
||||||
|
httpErr.Message = string(httpErr.Contents)
|
||||||
|
lastError = &httpErr
|
||||||
|
} else {
|
||||||
|
lastError = &gomatrix.HTTPError{
|
||||||
|
Code: 0,
|
||||||
|
WrappedError: nil,
|
||||||
|
Message: "Unknown HTTP error",
|
||||||
|
}
|
||||||
|
if lastErr != nil {
|
||||||
|
lastError.Message = lastErr.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.Errorf(
|
||||||
|
"failed to join user %q to room %q through %d server(s): last error %s",
|
||||||
|
request.UserID, request.RoomID, len(request.ServerNames), lastError,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *FederationInternalAPI) performSendJoinUsingServer(
|
||||||
|
ctx context.Context,
|
||||||
|
roomID, userID string,
|
||||||
|
unsigned map[string]interface{},
|
||||||
|
event gomatrixserverlib.PDU,
|
||||||
|
serverName spec.ServerName,
|
||||||
|
) error {
|
||||||
|
user, err := spec.NewUserID(userID, true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
room, err := spec.NewRoomID(roomID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, *room, *user)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
joinInput := gomatrixserverlib.PerformSendJoinInput{
|
||||||
|
RoomID: room,
|
||||||
|
ServerName: serverName,
|
||||||
|
Unsigned: unsigned,
|
||||||
|
Origin: user.Domain(),
|
||||||
|
SenderID: *senderID,
|
||||||
|
KeyRing: r.keyRing,
|
||||||
|
Event: event,
|
||||||
|
RoomVersion: event.Version(),
|
||||||
|
EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
|
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
|
}),
|
||||||
|
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
|
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
|
},
|
||||||
|
StoreSenderIDFromPublicID: func(ctx context.Context, senderID spec.SenderID, userIDRaw string, roomID spec.RoomID) error {
|
||||||
|
storeUserID, userErr := spec.NewUserID(userIDRaw, true)
|
||||||
|
if userErr != nil {
|
||||||
|
return userErr
|
||||||
|
}
|
||||||
|
return r.rsAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
response, joinErr := gomatrixserverlib.PerformSendJoin(ctx, r, joinInput)
|
||||||
|
if joinErr != nil {
|
||||||
|
if !joinErr.Reachable {
|
||||||
|
r.statistics.ForServer(joinErr.ServerName).Failure()
|
||||||
|
} else {
|
||||||
|
r.statistics.ForServer(joinErr.ServerName).Success(statistics.SendDirect)
|
||||||
|
}
|
||||||
|
return joinErr.Err
|
||||||
|
}
|
||||||
|
r.statistics.ForServer(serverName).Success(statistics.SendDirect)
|
||||||
|
if response == nil {
|
||||||
|
return fmt.Errorf("Received nil response from gomatrixserverlib.PerformSendJoin")
|
||||||
|
}
|
||||||
|
|
||||||
|
// We need to immediately update our list of joined hosts for this room now as we are technically
|
||||||
|
// joined. We must do this synchronously: we cannot rely on the roomserver output events as they
|
||||||
|
// will happen asyncly. If we don't update this table, you can end up with bad failure modes like
|
||||||
|
// joining a room, waiting for 200 OK then changing device keys and have those keys not be sent
|
||||||
|
// to other servers (this was a cause of a flakey sytest "Local device key changes get to remote servers")
|
||||||
|
// The events are trusted now as we performed auth checks above.
|
||||||
|
joinedHosts, err := consumers.JoinedHostsFromEvents(ctx, response.StateSnapshot.GetStateEvents().TrustedEvents(response.JoinEvent.Version(), false), r.rsAPI)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("JoinedHostsFromEvents: failed to get joined hosts: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.WithField("room", roomID).Infof("Joined federated room with %d hosts", len(joinedHosts))
|
||||||
|
if _, err = r.db.UpdateRoom(context.Background(), roomID, joinedHosts, nil, true); err != nil {
|
||||||
|
return fmt.Errorf("UpdatedRoom: failed to update room with joined hosts: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Can I change this to not take respState but instead just take an opaque list of events?
|
||||||
|
if err = roomserverAPI.SendEventWithState(
|
||||||
|
context.Background(),
|
||||||
|
r.rsAPI,
|
||||||
|
user.Domain(),
|
||||||
|
roomserverAPI.KindNew,
|
||||||
|
response.StateSnapshot,
|
||||||
|
&types.HeaderedEvent{PDU: response.JoinEvent},
|
||||||
|
serverName,
|
||||||
|
nil,
|
||||||
|
false,
|
||||||
|
); err != nil {
|
||||||
|
return fmt.Errorf("roomserverAPI.SendEventWithState: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// PerformOutboundPeekRequest implements api.FederationInternalAPI
|
// PerformOutboundPeekRequest implements api.FederationInternalAPI
|
||||||
func (r *FederationInternalAPI) PerformOutboundPeek(
|
func (r *FederationInternalAPI) PerformOutboundPeek(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
@ -433,6 +746,7 @@ func (r *FederationInternalAPI) PerformLeave(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *api.PerformLeaveRequest,
|
request *api.PerformLeaveRequest,
|
||||||
response *api.PerformLeaveResponse,
|
response *api.PerformLeaveResponse,
|
||||||
|
cryptoIDs bool,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
userID, err := spec.NewUserID(request.UserID, true)
|
userID, err := spec.NewUserID(request.UserID, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -649,6 +963,87 @@ func (r *FederationInternalAPI) SendInviteV3(
|
||||||
return inviteEvent, nil
|
return inviteEvent, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MakeInviteCryptoIDs implements api.FederationInternalAPI
|
||||||
|
func (r *FederationInternalAPI) MakeInviteCryptoIDs(ctx context.Context, event gomatrixserverlib.ProtoEvent, invitee spec.UserID, version gomatrixserverlib.RoomVersion, strippedState []gomatrixserverlib.InviteStrippedState) (gomatrixserverlib.PDU, error) {
|
||||||
|
validRoomID, err := spec.NewRoomID(event.RoomID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
verImpl, err := gomatrixserverlib.GetRoomVersion(version)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
inviter, err := r.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(event.SenderID))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO (devon): This should be allowed via a relay. Currently only transactions
|
||||||
|
// can be sent to relays. Would need to extend relays to handle invites.
|
||||||
|
if !r.shouldAttemptDirectFederation(invitee.Domain()) {
|
||||||
|
return nil, fmt.Errorf("relay servers have no meaningful response for invite.")
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"user_id": invitee.String(),
|
||||||
|
"room_id": event.RoomID,
|
||||||
|
"room_version": version,
|
||||||
|
"destination": invitee.Domain(),
|
||||||
|
}).Info("Sending /send_invite")
|
||||||
|
|
||||||
|
inviteReq, err := fclient.NewInviteV3Request(event, version, strippedState)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("gomatrixserverlib.NewInviteV3Request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
inviteRes, err := r.federation.MakeInviteCryptoIDs(ctx, inviter.Domain(), invitee.Domain(), inviteReq, invitee)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("r.federation.SendInviteCryptoIDs: failed to send invite: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
inviteEvent, err := verImpl.NewEventFromUntrustedJSON(inviteRes.Event)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("r.federation.SendInviteCryptoIDs failed to decode event response: %w", err)
|
||||||
|
}
|
||||||
|
return inviteEvent, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendInviteCryptoIDs implements api.FederationInternalAPI
|
||||||
|
func (r *FederationInternalAPI) SendInviteCryptoIDs(ctx context.Context, event gomatrixserverlib.PDU, invitee spec.UserID, version gomatrixserverlib.RoomVersion) error {
|
||||||
|
validRoomID := event.RoomID()
|
||||||
|
|
||||||
|
inviter, err := r.rsAPI.QueryUserIDForSender(ctx, validRoomID, event.SenderID())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO (devon): This should be allowed via a relay. Currently only transactions
|
||||||
|
// can be sent to relays. Would need to extend relays to handle invites.
|
||||||
|
if !r.shouldAttemptDirectFederation(invitee.Domain()) {
|
||||||
|
return fmt.Errorf("relay servers have no meaningful response for invite.")
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"user_id": invitee.String(),
|
||||||
|
"room_id": event.RoomID,
|
||||||
|
"room_version": version,
|
||||||
|
"destination": invitee.Domain(),
|
||||||
|
}).Info("Sending /make_invite")
|
||||||
|
|
||||||
|
inviteReq, err := fclient.NewSendInviteCryptoIDsRequest(event, version)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("gomatrixserverlib.NewInviteV3Request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = r.federation.SendInviteCryptoIDs(ctx, inviter.Domain(), invitee.Domain(), inviteReq, invitee)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("r.federation.SendInviteCryptoIDs: failed to send invite: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// PerformServersAlive implements api.FederationInternalAPI
|
// PerformServersAlive implements api.FederationInternalAPI
|
||||||
func (r *FederationInternalAPI) PerformBroadcastEDU(
|
func (r *FederationInternalAPI) PerformBroadcastEDU(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
|
14
go.mod
14
go.mod
|
@ -1,5 +1,9 @@
|
||||||
module github.com/matrix-org/dendrite
|
module github.com/matrix-org/dendrite
|
||||||
|
|
||||||
|
//replace github.com/matrix-org/gomatrixserverlib => ../../gomatrixserverlib/crypto-ids/
|
||||||
|
|
||||||
|
//replace github.com/matrix-org/gomatrixserverlib => /src/gmsl/
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/Arceliar/ironwood v0.0.0-20221025225125-45b4281814c2
|
github.com/Arceliar/ironwood v0.0.0-20221025225125-45b4281814c2
|
||||||
github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979
|
github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979
|
||||||
|
@ -22,7 +26,7 @@ require (
|
||||||
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
|
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
|
||||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
|
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
|
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230926165653-79fcff283fc4
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20231219232834-bbfb4a048862
|
||||||
github.com/matrix-org/pinecone v0.11.1-0.20230810010612-ea4c33717fd7
|
github.com/matrix-org/pinecone v0.11.1-0.20230810010612-ea4c33717fd7
|
||||||
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
|
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
|
||||||
github.com/mattn/go-sqlite3 v1.14.17
|
github.com/mattn/go-sqlite3 v1.14.17
|
||||||
|
@ -42,12 +46,12 @@ require (
|
||||||
github.com/uber/jaeger-lib v2.4.1+incompatible
|
github.com/uber/jaeger-lib v2.4.1+incompatible
|
||||||
github.com/yggdrasil-network/yggdrasil-go v0.4.6
|
github.com/yggdrasil-network/yggdrasil-go v0.4.6
|
||||||
go.uber.org/atomic v1.10.0
|
go.uber.org/atomic v1.10.0
|
||||||
golang.org/x/crypto v0.13.0
|
golang.org/x/crypto v0.17.0
|
||||||
golang.org/x/exp v0.0.0-20230809150735-7b3493d9a819
|
golang.org/x/exp v0.0.0-20230809150735-7b3493d9a819
|
||||||
golang.org/x/image v0.5.0
|
golang.org/x/image v0.5.0
|
||||||
golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e
|
golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e
|
||||||
golang.org/x/sync v0.3.0
|
golang.org/x/sync v0.3.0
|
||||||
golang.org/x/term v0.12.0
|
golang.org/x/term v0.15.0
|
||||||
gopkg.in/h2non/bimg.v1 v1.1.9
|
gopkg.in/h2non/bimg.v1 v1.1.9
|
||||||
gopkg.in/yaml.v2 v2.4.0
|
gopkg.in/yaml.v2 v2.4.0
|
||||||
gotest.tools/v3 v3.4.0
|
gotest.tools/v3 v3.4.0
|
||||||
|
@ -124,8 +128,8 @@ require (
|
||||||
go.etcd.io/bbolt v1.3.6 // indirect
|
go.etcd.io/bbolt v1.3.6 // indirect
|
||||||
golang.org/x/mod v0.12.0 // indirect
|
golang.org/x/mod v0.12.0 // indirect
|
||||||
golang.org/x/net v0.14.0 // indirect
|
golang.org/x/net v0.14.0 // indirect
|
||||||
golang.org/x/sys v0.12.0 // indirect
|
golang.org/x/sys v0.15.0 // indirect
|
||||||
golang.org/x/text v0.13.0 // indirect
|
golang.org/x/text v0.14.0 // indirect
|
||||||
golang.org/x/time v0.3.0 // indirect
|
golang.org/x/time v0.3.0 // indirect
|
||||||
golang.org/x/tools v0.12.0 // indirect
|
golang.org/x/tools v0.12.0 // indirect
|
||||||
google.golang.org/protobuf v1.30.0 // indirect
|
google.golang.org/protobuf v1.30.0 // indirect
|
||||||
|
|
20
go.sum
20
go.sum
|
@ -208,8 +208,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
|
||||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
|
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
|
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230926165653-79fcff283fc4 h1:UuXfC7b29RBDfMdLmggeF3opu3XuGi8bNT9SKZtZc3I=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20231219232834-bbfb4a048862 h1:Kuya3qas85ZvVVkuOpemwhgvdJbLojvwvt3xyJTp1dY=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230926165653-79fcff283fc4/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20231219232834-bbfb4a048862/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
|
||||||
github.com/matrix-org/pinecone v0.11.1-0.20230810010612-ea4c33717fd7 h1:6t8kJr8i1/1I5nNttw6nn1ryQJgzVlBmSGgPiiaTdw4=
|
github.com/matrix-org/pinecone v0.11.1-0.20230810010612-ea4c33717fd7 h1:6t8kJr8i1/1I5nNttw6nn1ryQJgzVlBmSGgPiiaTdw4=
|
||||||
github.com/matrix-org/pinecone v0.11.1-0.20230810010612-ea4c33717fd7/go.mod h1:ReWMS/LoVnOiRAdq9sNUC2NZnd1mZkMNB52QhpTRWjg=
|
github.com/matrix-org/pinecone v0.11.1-0.20230810010612-ea4c33717fd7/go.mod h1:ReWMS/LoVnOiRAdq9sNUC2NZnd1mZkMNB52QhpTRWjg=
|
||||||
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
|
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
|
||||||
|
@ -354,8 +354,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
|
||||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck=
|
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
|
||||||
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
|
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
|
||||||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
|
@ -418,19 +418,19 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
|
||||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||||
golang.org/x/term v0.12.0 h1:/ZfYdc3zq+q02Rv9vGqTeSItdzZTSNDmfTi0mBAuidU=
|
golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4=
|
||||||
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
|
golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
|
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
|
||||||
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
|
|
|
@ -85,13 +85,21 @@ func BuildEvent(
|
||||||
}
|
}
|
||||||
builder := verImpl.NewEventBuilderFromProtoEvent(proto)
|
builder := verImpl.NewEventBuilderFromProtoEvent(proto)
|
||||||
|
|
||||||
event, err := builder.Build(
|
var event gomatrixserverlib.PDU
|
||||||
|
if identity.PrivateKey != nil {
|
||||||
|
event, err = builder.Build(
|
||||||
evTime, identity.ServerName, identity.KeyID,
|
evTime, identity.ServerName, identity.KeyID,
|
||||||
identity.PrivateKey,
|
identity.PrivateKey,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
event, err = builder.BuildWithoutSigning(evTime, identity.ServerName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &types.HeaderedEvent{PDU: event}, nil
|
return &types.HeaderedEvent{PDU: event}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -92,6 +92,7 @@ type UserRoomPrivateKeyCreator interface {
|
||||||
// GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created.
|
// GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created.
|
||||||
GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error)
|
GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error)
|
||||||
StoreUserRoomPublicKey(ctx context.Context, senderID spec.SenderID, userID spec.UserID, roomID spec.RoomID) error
|
StoreUserRoomPublicKey(ctx context.Context, senderID spec.SenderID, userID spec.UserID, roomID spec.RoomID) error
|
||||||
|
ClaimOneTimeSenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type InputRoomEventsAPI interface {
|
type InputRoomEventsAPI interface {
|
||||||
|
@ -222,6 +223,7 @@ type ClientRoomserverAPI interface {
|
||||||
DefaultRoomVersionAPI
|
DefaultRoomVersionAPI
|
||||||
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
|
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
|
||||||
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
|
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
|
||||||
|
InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error)
|
||||||
QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error)
|
QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error)
|
||||||
QueryStateAfterEvents(ctx context.Context, req *QueryStateAfterEventsRequest, res *QueryStateAfterEventsResponse) error
|
QueryStateAfterEvents(ctx context.Context, req *QueryStateAfterEventsRequest, res *QueryStateAfterEventsResponse) error
|
||||||
// QueryKnownUsers returns a list of users that we know about from our joined rooms.
|
// QueryKnownUsers returns a list of users that we know about from our joined rooms.
|
||||||
|
@ -232,6 +234,7 @@ type ClientRoomserverAPI interface {
|
||||||
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error
|
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error
|
||||||
GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error
|
GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error
|
||||||
|
|
||||||
|
PerformCreateRoomCryptoIDs(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) ([]gomatrixserverlib.PDU, error)
|
||||||
PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse)
|
PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse)
|
||||||
// PerformRoomUpgrade upgrades a room to a newer version
|
// PerformRoomUpgrade upgrades a room to a newer version
|
||||||
PerformRoomUpgrade(ctx context.Context, roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error)
|
PerformRoomUpgrade(ctx context.Context, roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error)
|
||||||
|
@ -241,9 +244,12 @@ type ClientRoomserverAPI interface {
|
||||||
PerformAdminDownloadState(ctx context.Context, roomID, userID string, serverName spec.ServerName) error
|
PerformAdminDownloadState(ctx context.Context, roomID, userID string, serverName spec.ServerName) error
|
||||||
PerformPeek(ctx context.Context, req *PerformPeekRequest) (roomID string, err error)
|
PerformPeek(ctx context.Context, req *PerformPeekRequest) (roomID string, err error)
|
||||||
PerformUnpeek(ctx context.Context, roomID, userID, deviceID string) error
|
PerformUnpeek(ctx context.Context, roomID, userID, deviceID string) error
|
||||||
PerformInvite(ctx context.Context, req *PerformInviteRequest) error
|
PerformInvite(ctx context.Context, req *PerformInviteRequest, cryptoIDs bool) (gomatrixserverlib.PDU, error)
|
||||||
PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error)
|
PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error)
|
||||||
PerformLeave(ctx context.Context, req *PerformLeaveRequest, res *PerformLeaveResponse) error
|
PerformSendJoinCryptoIDs(ctx context.Context, req *PerformJoinRequestCryptoIDs) error
|
||||||
|
PerformSendInviteCryptoIDs(ctx context.Context, req *PerformInviteRequestCryptoIDs) error
|
||||||
|
PerformJoinCryptoIDs(ctx context.Context, req *PerformJoinRequest) (join gomatrixserverlib.PDU, roomID string, version gomatrixserverlib.RoomVersion, serverName spec.ServerName, err error)
|
||||||
|
PerformLeave(ctx context.Context, req *PerformLeaveRequest, res *PerformLeaveResponse, cryptoIDs bool) (gomatrixserverlib.PDU, error)
|
||||||
PerformPublish(ctx context.Context, req *PerformPublishRequest) error
|
PerformPublish(ctx context.Context, req *PerformPublishRequest) error
|
||||||
// PerformForget forgets a rooms history for a specific user
|
// PerformForget forgets a rooms history for a specific user
|
||||||
PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error
|
PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error
|
||||||
|
@ -305,7 +311,7 @@ type FederationRoomserverAPI interface {
|
||||||
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
|
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
|
||||||
HandleInvite(ctx context.Context, event *types.HeaderedEvent) error
|
HandleInvite(ctx context.Context, event *types.HeaderedEvent) error
|
||||||
|
|
||||||
PerformInvite(ctx context.Context, req *PerformInviteRequest) error
|
PerformInvite(ctx context.Context, req *PerformInviteRequest, cryptoIDs bool) (gomatrixserverlib.PDU, error)
|
||||||
// Query a given amount (or less) of events prior to a given set of events.
|
// Query a given amount (or less) of events prior to a given set of events.
|
||||||
PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error
|
PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error
|
||||||
|
|
||||||
|
|
|
@ -29,6 +29,8 @@ type PerformCreateRoomRequest struct {
|
||||||
KeyID gomatrixserverlib.KeyID
|
KeyID gomatrixserverlib.KeyID
|
||||||
PrivateKey ed25519.PrivateKey
|
PrivateKey ed25519.PrivateKey
|
||||||
EventTime time.Time
|
EventTime time.Time
|
||||||
|
|
||||||
|
SenderID string
|
||||||
}
|
}
|
||||||
|
|
||||||
type PerformJoinRequest struct {
|
type PerformJoinRequest struct {
|
||||||
|
@ -38,6 +40,23 @@ type PerformJoinRequest struct {
|
||||||
Content map[string]interface{} `json:"content"`
|
Content map[string]interface{} `json:"content"`
|
||||||
ServerNames []spec.ServerName `json:"server_names"`
|
ServerNames []spec.ServerName `json:"server_names"`
|
||||||
Unsigned map[string]interface{} `json:"unsigned"`
|
Unsigned map[string]interface{} `json:"unsigned"`
|
||||||
|
SenderID spec.SenderID
|
||||||
|
}
|
||||||
|
|
||||||
|
type PerformJoinRequestCryptoIDs struct {
|
||||||
|
RoomID string
|
||||||
|
UserID string
|
||||||
|
IsGuest bool
|
||||||
|
Content map[string]interface{}
|
||||||
|
ServerNames []spec.ServerName
|
||||||
|
Unsigned map[string]interface{}
|
||||||
|
JoinEvent gomatrixserverlib.PDU
|
||||||
|
}
|
||||||
|
|
||||||
|
type PerformInviteRequestCryptoIDs struct {
|
||||||
|
RoomID string
|
||||||
|
UserID spec.UserID
|
||||||
|
InviteEvent gomatrixserverlib.PDU
|
||||||
}
|
}
|
||||||
|
|
||||||
type PerformLeaveRequest struct {
|
type PerformLeaveRequest struct {
|
||||||
|
|
|
@ -3,6 +3,7 @@ package internal
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/getsentry/sentry-go"
|
"github.com/getsentry/sentry-go"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
@ -54,6 +55,7 @@ type RoomserverInternalAPI struct {
|
||||||
ServerACLs *acls.ServerACLs
|
ServerACLs *acls.ServerACLs
|
||||||
fsAPI fsAPI.RoomserverFederationAPI
|
fsAPI fsAPI.RoomserverFederationAPI
|
||||||
asAPI asAPI.AppServiceInternalAPI
|
asAPI asAPI.AppServiceInternalAPI
|
||||||
|
usAPI userapi.RoomserverUserAPI
|
||||||
NATSClient *nats.Conn
|
NATSClient *nats.Conn
|
||||||
JetStream nats.JetStreamContext
|
JetStream nats.JetStreamContext
|
||||||
Durable string
|
Durable string
|
||||||
|
@ -214,6 +216,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
||||||
func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) {
|
func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) {
|
||||||
r.Leaver.UserAPI = userAPI
|
r.Leaver.UserAPI = userAPI
|
||||||
r.Inputer.UserAPI = userAPI
|
r.Inputer.UserAPI = userAPI
|
||||||
|
r.usAPI = userAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) {
|
func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) {
|
||||||
|
@ -251,24 +254,27 @@ func (r *RoomserverInternalAPI) PerformCreateRoom(
|
||||||
func (r *RoomserverInternalAPI) PerformInvite(
|
func (r *RoomserverInternalAPI) PerformInvite(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *api.PerformInviteRequest,
|
req *api.PerformInviteRequest,
|
||||||
) error {
|
cryptoIDs bool,
|
||||||
return r.Inviter.PerformInvite(ctx, req)
|
) (gomatrixserverlib.PDU, error) {
|
||||||
|
return r.Inviter.PerformInvite(ctx, req, cryptoIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) PerformLeave(
|
func (r *RoomserverInternalAPI) PerformLeave(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *api.PerformLeaveRequest,
|
req *api.PerformLeaveRequest,
|
||||||
res *api.PerformLeaveResponse,
|
res *api.PerformLeaveResponse,
|
||||||
) error {
|
cryptoIDs bool,
|
||||||
outputEvents, err := r.Leaver.PerformLeave(ctx, req, res)
|
) (gomatrixserverlib.PDU, error) {
|
||||||
|
outputEvents, leaveEvent, err := r.Leaver.PerformLeave(ctx, req, res, cryptoIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sentry.CaptureException(err)
|
sentry.CaptureException(err)
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
if len(outputEvents) == 0 {
|
if len(outputEvents) == 0 {
|
||||||
return nil
|
return leaveEvent, nil
|
||||||
}
|
}
|
||||||
return r.OutputProducer.ProduceRoomEvents(req.RoomID, outputEvents)
|
// TODO: cryptoIDs - what to do with this?
|
||||||
|
return leaveEvent, r.OutputProducer.ProduceRoomEvents(req.RoomID, outputEvents)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) PerformForget(
|
func (r *RoomserverInternalAPI) PerformForget(
|
||||||
|
@ -308,6 +314,10 @@ func (r *RoomserverInternalAPI) StoreUserRoomPublicKey(ctx context.Context, send
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *RoomserverInternalAPI) ClaimOneTimeSenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
|
||||||
|
return r.usAPI.ClaimOneTimeCryptoID(ctx, roomID, userID)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) {
|
func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) {
|
||||||
roomVersion, ok := r.Cache.GetRoomVersion(roomID.String())
|
roomVersion, ok := r.Cache.GetRoomVersion(roomID.String())
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -330,6 +340,19 @@ func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID s
|
||||||
ServerName: spec.ServerName(spec.SenderIDFromPseudoIDKey(privKey)),
|
ServerName: spec.ServerName(spec.SenderIDFromPseudoIDKey(privKey)),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
if roomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
|
||||||
|
sender, err := r.QuerySenderIDForUser(ctx, roomID, senderID)
|
||||||
|
if err != nil {
|
||||||
|
return fclient.SigningIdentity{}, err
|
||||||
|
} else if sender == nil {
|
||||||
|
return fclient.SigningIdentity{}, fmt.Errorf("no sender ID for %s in %s", senderID.String(), roomID.String())
|
||||||
|
}
|
||||||
|
return fclient.SigningIdentity{
|
||||||
|
PrivateKey: nil,
|
||||||
|
KeyID: "ed25519:1",
|
||||||
|
ServerName: spec.ServerName(*sender),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
identity, err := r.Cfg.Global.SigningIdentityFor(senderID.Domain())
|
identity, err := r.Cfg.Global.SigningIdentityFor(senderID.Domain())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fclient.SigningIdentity{}, err
|
return fclient.SigningIdentity{}, err
|
||||||
|
|
|
@ -445,7 +445,7 @@ func (r *Inputer) processRoomEvent(
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Revist this to ensure we don't replace a current state mxid_mapping with an older one.
|
// TODO: Revist this to ensure we don't replace a current state mxid_mapping with an older one.
|
||||||
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && event.Type() == spec.MRoomMember {
|
if (event.Version() == gomatrixserverlib.RoomVersionPseudoIDs || event.Version() == gomatrixserverlib.RoomVersionCryptoIDs) && event.Type() == spec.MRoomMember {
|
||||||
mapping := gomatrixserverlib.MemberContent{}
|
mapping := gomatrixserverlib.MemberContent{}
|
||||||
if err = json.Unmarshal(event.Content(), &mapping); err != nil {
|
if err = json.Unmarshal(event.Content(), &mapping); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -179,7 +179,7 @@ func (r *Admin) PerformAdminEvacuateUser(
|
||||||
Leaver: *fullUserID,
|
Leaver: *fullUserID,
|
||||||
}
|
}
|
||||||
leaveRes := &api.PerformLeaveResponse{}
|
leaveRes := &api.PerformLeaveResponse{}
|
||||||
outputEvents, err := r.Leaver.PerformLeave(ctx, leaveReq, leaveRes)
|
outputEvents, _, err := r.Leaver.PerformLeave(ctx, leaveReq, leaveRes, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,6 +44,417 @@ type Creator struct {
|
||||||
RSAPI api.RoomserverInternalAPI
|
RSAPI api.RoomserverInternalAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PerformCreateRoomCryptoIDs handles all the steps necessary to create a new room.
|
||||||
|
// nolint: gocyclo
|
||||||
|
func (c *Creator) PerformCreateRoomCryptoIDs(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *api.PerformCreateRoomRequest) ([]gomatrixserverlib.PDU, error) {
|
||||||
|
verImpl, err := gomatrixserverlib.GetRoomVersion(createRequest.RoomVersion)
|
||||||
|
if err != nil {
|
||||||
|
return nil, spec.BadJSON("unknown room version")
|
||||||
|
}
|
||||||
|
|
||||||
|
createContent := map[string]interface{}{}
|
||||||
|
if len(createRequest.CreationContent) > 0 {
|
||||||
|
if err = json.Unmarshal(createRequest.CreationContent, &createContent); err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for creation_content failed")
|
||||||
|
return nil, spec.BadJSON("invalid create content")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
senderID := spec.SenderID(createRequest.SenderID)
|
||||||
|
|
||||||
|
// TODO: cryptoIDs - should we be assigning a room NID yet?
|
||||||
|
_, err = c.DB.AssignRoomNID(ctx, roomID, createRequest.RoomVersion)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("failed to assign roomNID")
|
||||||
|
return nil, spec.InternalServerError{Err: err.Error()}
|
||||||
|
}
|
||||||
|
|
||||||
|
if createRequest.RoomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
|
||||||
|
util.GetLogger(ctx).Infof("StoreUserRoomPublicKey - SenderID: %s UserID: %s RoomID: %s", senderID, userID.String(), roomID.String())
|
||||||
|
bytes := spec.Base64Bytes{}
|
||||||
|
err = bytes.Decode(string(senderID))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(bytes) != ed25519.PublicKeySize {
|
||||||
|
return nil, spec.BadJSON("SenderID is not a valid ed25519 public key")
|
||||||
|
}
|
||||||
|
|
||||||
|
keyErr := c.RSAPI.StoreUserRoomPublicKey(ctx, senderID, userID, roomID)
|
||||||
|
if keyErr != nil {
|
||||||
|
util.GetLogger(ctx).WithError(keyErr).Error("StoreUserRoomPublicKey failed")
|
||||||
|
return nil, spec.InternalServerError{Err: keyErr.Error()}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
createContent["creator"] = senderID
|
||||||
|
createContent["room_version"] = createRequest.RoomVersion
|
||||||
|
powerLevelContent := eventutil.InitialPowerLevelsContent(string(senderID))
|
||||||
|
joinRuleContent := gomatrixserverlib.JoinRuleContent{
|
||||||
|
JoinRule: spec.Invite,
|
||||||
|
}
|
||||||
|
historyVisibilityContent := gomatrixserverlib.HistoryVisibilityContent{
|
||||||
|
HistoryVisibility: historyVisibilityShared,
|
||||||
|
}
|
||||||
|
|
||||||
|
if createRequest.PowerLevelContentOverride != nil {
|
||||||
|
// Merge powerLevelContentOverride fields by unmarshalling it atop the defaults
|
||||||
|
err = json.Unmarshal(createRequest.PowerLevelContentOverride, &powerLevelContent)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for power_level_content_override failed")
|
||||||
|
return nil, spec.BadJSON("malformed power_level_content_override")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var guestsCanJoin bool
|
||||||
|
switch createRequest.StatePreset {
|
||||||
|
case spec.PresetPrivateChat:
|
||||||
|
joinRuleContent.JoinRule = spec.Invite
|
||||||
|
historyVisibilityContent.HistoryVisibility = historyVisibilityShared
|
||||||
|
guestsCanJoin = true
|
||||||
|
case spec.PresetTrustedPrivateChat:
|
||||||
|
joinRuleContent.JoinRule = spec.Invite
|
||||||
|
historyVisibilityContent.HistoryVisibility = historyVisibilityShared
|
||||||
|
for _, invitee := range createRequest.InvitedUsers {
|
||||||
|
powerLevelContent.Users[invitee] = 100
|
||||||
|
}
|
||||||
|
guestsCanJoin = true
|
||||||
|
case spec.PresetPublicChat:
|
||||||
|
joinRuleContent.JoinRule = spec.Public
|
||||||
|
historyVisibilityContent.HistoryVisibility = historyVisibilityShared
|
||||||
|
}
|
||||||
|
|
||||||
|
createEvent := gomatrixserverlib.FledglingEvent{
|
||||||
|
Type: spec.MRoomCreate,
|
||||||
|
Content: createContent,
|
||||||
|
}
|
||||||
|
powerLevelEvent := gomatrixserverlib.FledglingEvent{
|
||||||
|
Type: spec.MRoomPowerLevels,
|
||||||
|
Content: powerLevelContent,
|
||||||
|
}
|
||||||
|
joinRuleEvent := gomatrixserverlib.FledglingEvent{
|
||||||
|
Type: spec.MRoomJoinRules,
|
||||||
|
Content: joinRuleContent,
|
||||||
|
}
|
||||||
|
historyVisibilityEvent := gomatrixserverlib.FledglingEvent{
|
||||||
|
Type: spec.MRoomHistoryVisibility,
|
||||||
|
Content: historyVisibilityContent,
|
||||||
|
}
|
||||||
|
membershipEvent := gomatrixserverlib.FledglingEvent{
|
||||||
|
Type: spec.MRoomMember,
|
||||||
|
StateKey: string(senderID),
|
||||||
|
}
|
||||||
|
|
||||||
|
memberContent := gomatrixserverlib.MemberContent{
|
||||||
|
Membership: spec.Join,
|
||||||
|
DisplayName: createRequest.UserDisplayName,
|
||||||
|
AvatarURL: createRequest.UserAvatarURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we are creating a room with pseudo IDs, create and sign the MXIDMapping
|
||||||
|
if createRequest.RoomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
|
||||||
|
mapping := &gomatrixserverlib.MXIDMapping{
|
||||||
|
UserRoomKey: senderID,
|
||||||
|
UserID: userID.String(),
|
||||||
|
}
|
||||||
|
|
||||||
|
identity, idErr := c.Cfg.Matrix.SigningIdentityFor(userID.Domain()) // we MUST use the server signing mxid_mapping
|
||||||
|
if idErr != nil {
|
||||||
|
logrus.WithError(idErr).WithField("domain", userID.Domain()).Error("unable to find signing identity for domain")
|
||||||
|
return nil, spec.InternalServerError{Err: idErr.Error()}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sign the mapping with the server identity
|
||||||
|
if err = mapping.Sign(identity.ServerName, identity.KeyID, identity.PrivateKey); err != nil {
|
||||||
|
return nil, spec.InternalServerError{Err: err.Error()}
|
||||||
|
}
|
||||||
|
memberContent.MXIDMapping = mapping
|
||||||
|
}
|
||||||
|
membershipEvent.Content = memberContent
|
||||||
|
|
||||||
|
var nameEvent *gomatrixserverlib.FledglingEvent
|
||||||
|
var topicEvent *gomatrixserverlib.FledglingEvent
|
||||||
|
var guestAccessEvent *gomatrixserverlib.FledglingEvent
|
||||||
|
var aliasEvent *gomatrixserverlib.FledglingEvent
|
||||||
|
|
||||||
|
if createRequest.RoomName != "" {
|
||||||
|
nameEvent = &gomatrixserverlib.FledglingEvent{
|
||||||
|
Type: spec.MRoomName,
|
||||||
|
Content: eventutil.NameContent{
|
||||||
|
Name: createRequest.RoomName,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if createRequest.Topic != "" {
|
||||||
|
topicEvent = &gomatrixserverlib.FledglingEvent{
|
||||||
|
Type: spec.MRoomTopic,
|
||||||
|
Content: eventutil.TopicContent{
|
||||||
|
Topic: createRequest.Topic,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if guestsCanJoin {
|
||||||
|
guestAccessEvent = &gomatrixserverlib.FledglingEvent{
|
||||||
|
Type: spec.MRoomGuestAccess,
|
||||||
|
Content: eventutil.GuestAccessContent{
|
||||||
|
GuestAccess: "can_join",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var roomAlias string
|
||||||
|
if createRequest.RoomAliasName != "" {
|
||||||
|
roomAlias = fmt.Sprintf("#%s:%s", createRequest.RoomAliasName, userID.Domain())
|
||||||
|
// check it's free
|
||||||
|
// TODO: This races but is better than nothing
|
||||||
|
hasAliasReq := api.GetRoomIDForAliasRequest{
|
||||||
|
Alias: roomAlias,
|
||||||
|
IncludeAppservices: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
var aliasResp api.GetRoomIDForAliasResponse
|
||||||
|
err = c.RSAPI.GetRoomIDForAlias(ctx, &hasAliasReq, &aliasResp)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed")
|
||||||
|
return nil, spec.InternalServerError{Err: err.Error()}
|
||||||
|
}
|
||||||
|
if aliasResp.RoomID != "" {
|
||||||
|
return nil, spec.RoomInUse("Room ID already exists.")
|
||||||
|
}
|
||||||
|
|
||||||
|
aliasEvent = &gomatrixserverlib.FledglingEvent{
|
||||||
|
Type: spec.MRoomCanonicalAlias,
|
||||||
|
Content: eventutil.CanonicalAlias{
|
||||||
|
Alias: roomAlias,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var initialStateEvents []gomatrixserverlib.FledglingEvent
|
||||||
|
for i := range createRequest.InitialState {
|
||||||
|
if createRequest.InitialState[i].StateKey != "" {
|
||||||
|
initialStateEvents = append(initialStateEvents, createRequest.InitialState[i])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch createRequest.InitialState[i].Type {
|
||||||
|
case spec.MRoomCreate:
|
||||||
|
continue
|
||||||
|
|
||||||
|
case spec.MRoomPowerLevels:
|
||||||
|
powerLevelEvent = createRequest.InitialState[i]
|
||||||
|
|
||||||
|
case spec.MRoomJoinRules:
|
||||||
|
joinRuleEvent = createRequest.InitialState[i]
|
||||||
|
|
||||||
|
case spec.MRoomHistoryVisibility:
|
||||||
|
historyVisibilityEvent = createRequest.InitialState[i]
|
||||||
|
|
||||||
|
case spec.MRoomGuestAccess:
|
||||||
|
guestAccessEvent = &createRequest.InitialState[i]
|
||||||
|
|
||||||
|
case spec.MRoomName:
|
||||||
|
nameEvent = &createRequest.InitialState[i]
|
||||||
|
|
||||||
|
case spec.MRoomTopic:
|
||||||
|
topicEvent = &createRequest.InitialState[i]
|
||||||
|
|
||||||
|
default:
|
||||||
|
initialStateEvents = append(initialStateEvents, createRequest.InitialState[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// send events into the room in order of:
|
||||||
|
// 1- m.room.create
|
||||||
|
// 2- room creator join member
|
||||||
|
// 3- m.room.power_levels
|
||||||
|
// 4- m.room.join_rules
|
||||||
|
// 5- m.room.history_visibility
|
||||||
|
// 6- m.room.canonical_alias (opt)
|
||||||
|
// 7- m.room.guest_access (opt)
|
||||||
|
// 8- other initial state items
|
||||||
|
// 9- m.room.name (opt)
|
||||||
|
// 10- m.room.topic (opt)
|
||||||
|
// 11- invite events (opt) - with is_direct flag if applicable TODO
|
||||||
|
// 12- 3pid invite events (opt) TODO
|
||||||
|
// This differs from Synapse slightly. Synapse would vary the ordering of 3-7
|
||||||
|
// depending on if those events were in "initial_state" or not. This made it
|
||||||
|
// harder to reason about, hence sticking to a strict static ordering.
|
||||||
|
eventsToMake := []gomatrixserverlib.FledglingEvent{
|
||||||
|
createEvent, membershipEvent, powerLevelEvent, joinRuleEvent, historyVisibilityEvent,
|
||||||
|
}
|
||||||
|
if guestAccessEvent != nil {
|
||||||
|
eventsToMake = append(eventsToMake, *guestAccessEvent)
|
||||||
|
}
|
||||||
|
eventsToMake = append(eventsToMake, initialStateEvents...)
|
||||||
|
if nameEvent != nil {
|
||||||
|
eventsToMake = append(eventsToMake, *nameEvent)
|
||||||
|
}
|
||||||
|
if topicEvent != nil {
|
||||||
|
eventsToMake = append(eventsToMake, *topicEvent)
|
||||||
|
}
|
||||||
|
if aliasEvent != nil {
|
||||||
|
// TODO: bit of a chicken and egg problem here as the alias doesn't exist and cannot until we have made the room.
|
||||||
|
// This means we might fail creating the alias but say the canonical alias is something that doesn't exist.
|
||||||
|
eventsToMake = append(eventsToMake, *aliasEvent)
|
||||||
|
}
|
||||||
|
|
||||||
|
var builtEvents []gomatrixserverlib.PDU
|
||||||
|
authEvents := gomatrixserverlib.NewAuthEvents(nil)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("rsapi.QuerySenderIDForUser failed")
|
||||||
|
return nil, spec.InternalServerError{Err: err.Error()}
|
||||||
|
}
|
||||||
|
for i, e := range eventsToMake {
|
||||||
|
depth := i + 1 // depth starts at 1
|
||||||
|
|
||||||
|
builder := verImpl.NewEventBuilderFromProtoEvent(&gomatrixserverlib.ProtoEvent{
|
||||||
|
SenderID: string(senderID),
|
||||||
|
RoomID: roomID.String(),
|
||||||
|
Type: e.Type,
|
||||||
|
StateKey: &e.StateKey,
|
||||||
|
Depth: int64(depth),
|
||||||
|
})
|
||||||
|
err = builder.SetContent(e.Content)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("builder.SetContent failed")
|
||||||
|
return nil, spec.InternalServerError{Err: err.Error()}
|
||||||
|
}
|
||||||
|
if i > 0 {
|
||||||
|
builder.PrevEvents = []string{builtEvents[i-1].EventID()}
|
||||||
|
}
|
||||||
|
var ev gomatrixserverlib.PDU
|
||||||
|
if err = builder.AddAuthEvents(&authEvents); err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("AddAuthEvents failed")
|
||||||
|
return nil, spec.InternalServerError{Err: err.Error()}
|
||||||
|
}
|
||||||
|
|
||||||
|
ev, err = builder.BuildWithoutSigning(createRequest.EventTime, userID.Domain())
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("buildEvent failed")
|
||||||
|
return nil, spec.InternalServerError{Err: err.Error()}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
|
return c.RSAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
|
}); err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed")
|
||||||
|
return nil, spec.InternalServerError{Err: err.Error()}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the event to the list of auth events
|
||||||
|
builtEvents = append(builtEvents, ev)
|
||||||
|
err = authEvents.AddEvent(ev)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("authEvents.AddEvent failed")
|
||||||
|
return nil, spec.InternalServerError{Err: err.Error()}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(#269): Reserve room alias while we create the room. This stops us
|
||||||
|
// from creating the room but still failing due to the alias having already
|
||||||
|
// been taken.
|
||||||
|
if roomAlias != "" {
|
||||||
|
aliasAlreadyExists, aliasErr := c.RSAPI.SetRoomAlias(ctx, senderID, roomID, roomAlias)
|
||||||
|
if aliasErr != nil {
|
||||||
|
util.GetLogger(ctx).WithError(aliasErr).Error("aliasAPI.SetRoomAlias failed")
|
||||||
|
return nil, spec.InternalServerError{Err: aliasErr.Error()}
|
||||||
|
}
|
||||||
|
|
||||||
|
if aliasAlreadyExists {
|
||||||
|
return nil, spec.RoomInUse("Room alias already exists.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: cryptoIDs - this shouldn't really be done until the client calls /sendPDUs with these events
|
||||||
|
// But that would require the visibility setting also being passed along
|
||||||
|
if createRequest.Visibility == spec.Public {
|
||||||
|
// expose this room in the published room list
|
||||||
|
if err = c.RSAPI.PerformPublish(ctx, &api.PerformPublishRequest{
|
||||||
|
RoomID: roomID.String(),
|
||||||
|
Visibility: spec.Public,
|
||||||
|
}); err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("failed to publish room")
|
||||||
|
return nil, spec.InternalServerError{Err: err.Error()}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this is a direct message then we should invite the participants.
|
||||||
|
//if len(createRequest.InvitedUsers) > 0 {
|
||||||
|
// Build some stripped state for the invite.
|
||||||
|
//var globalStrippedState []gomatrixserverlib.InviteStrippedState
|
||||||
|
//for _, event := range builtEvents {
|
||||||
|
// // Chosen events from the spec:
|
||||||
|
// // https://spec.matrix.org/v1.3/client-server-api/#stripped-state
|
||||||
|
// switch event.Type() {
|
||||||
|
// case spec.MRoomCreate:
|
||||||
|
// fallthrough
|
||||||
|
// case spec.MRoomName:
|
||||||
|
// fallthrough
|
||||||
|
// case spec.MRoomAvatar:
|
||||||
|
// fallthrough
|
||||||
|
// case spec.MRoomTopic:
|
||||||
|
// fallthrough
|
||||||
|
// case spec.MRoomCanonicalAlias:
|
||||||
|
// fallthrough
|
||||||
|
// case spec.MRoomEncryption:
|
||||||
|
// fallthrough
|
||||||
|
// case spec.MRoomMember:
|
||||||
|
// fallthrough
|
||||||
|
// case spec.MRoomJoinRules:
|
||||||
|
// ev := event
|
||||||
|
// globalStrippedState = append(
|
||||||
|
// globalStrippedState,
|
||||||
|
// gomatrixserverlib.NewInviteStrippedState(ev),
|
||||||
|
// )
|
||||||
|
// }
|
||||||
|
//}
|
||||||
|
|
||||||
|
// Process the invites.
|
||||||
|
//for _, invitee := range createRequest.InvitedUsers {
|
||||||
|
//inviteeUserID, userIDErr := spec.NewUserID(invitee, true)
|
||||||
|
//if userIDErr != nil {
|
||||||
|
// util.GetLogger(ctx).WithError(userIDErr).Error("invalid UserID")
|
||||||
|
// return nil, spec.InternalServerError{}
|
||||||
|
//}
|
||||||
|
|
||||||
|
// TODO: cryptoIDs - these shouldn't be here
|
||||||
|
// instead we should return proto invite events?
|
||||||
|
//err = c.RSAPI.PerformInvite(ctx, &api.PerformInviteRequest{
|
||||||
|
// InviteInput: api.InviteInput{
|
||||||
|
// RoomID: roomID,
|
||||||
|
// Inviter: userID,
|
||||||
|
// Invitee: *inviteeUserID,
|
||||||
|
// DisplayName: createRequest.UserDisplayName,
|
||||||
|
// AvatarURL: createRequest.UserAvatarURL,
|
||||||
|
// Reason: "",
|
||||||
|
// IsDirect: createRequest.IsDirect,
|
||||||
|
// KeyID: createRequest.KeyID,
|
||||||
|
// PrivateKey: createRequest.PrivateKey,
|
||||||
|
// EventTime: createRequest.EventTime,
|
||||||
|
// },
|
||||||
|
// InviteRoomState: globalStrippedState,
|
||||||
|
// SendAsServer: string(userID.Domain()),
|
||||||
|
//})
|
||||||
|
//switch e := err.(type) {
|
||||||
|
//case api.ErrInvalidID:
|
||||||
|
// return nil, spec.Unknown(e.Error())
|
||||||
|
//case api.ErrNotAllowed:
|
||||||
|
// return nil, spec.Forbidden(e.Error())
|
||||||
|
//case nil:
|
||||||
|
//default:
|
||||||
|
// util.GetLogger(ctx).WithError(err).Error("PerformInvite failed")
|
||||||
|
// sentry.CaptureException(err)
|
||||||
|
// return nil, spec.InternalServerError{}
|
||||||
|
//}
|
||||||
|
//}
|
||||||
|
//}
|
||||||
|
|
||||||
|
return builtEvents, nil
|
||||||
|
}
|
||||||
|
|
||||||
// PerformCreateRoom handles all the steps necessary to create a new room.
|
// PerformCreateRoom handles all the steps necessary to create a new room.
|
||||||
// nolint: gocyclo
|
// nolint: gocyclo
|
||||||
func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *api.PerformCreateRoomRequest) (string, *util.JSONResponse) {
|
func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *api.PerformCreateRoomRequest) (string, *util.JSONResponse) {
|
||||||
|
@ -501,7 +912,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = c.RSAPI.PerformInvite(ctx, &api.PerformInviteRequest{
|
_, err = c.RSAPI.PerformInvite(ctx, &api.PerformInviteRequest{
|
||||||
InviteInput: api.InviteInput{
|
InviteInput: api.InviteInput{
|
||||||
RoomID: roomID,
|
RoomID: roomID,
|
||||||
Inviter: userID,
|
Inviter: userID,
|
||||||
|
@ -516,7 +927,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
|
||||||
},
|
},
|
||||||
InviteRoomState: globalStrippedState,
|
InviteRoomState: globalStrippedState,
|
||||||
SendAsServer: string(userID.Domain()),
|
SendAsServer: string(userID.Domain()),
|
||||||
})
|
}, false)
|
||||||
switch e := err.(type) {
|
switch e := err.(type) {
|
||||||
case api.ErrInvalidID:
|
case api.ErrInvalidID:
|
||||||
return "", &util.JSONResponse{
|
return "", &util.JSONResponse{
|
||||||
|
|
|
@ -125,16 +125,17 @@ func (r *Inviter) ProcessInviteMembership(
|
||||||
func (r *Inviter) PerformInvite(
|
func (r *Inviter) PerformInvite(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *api.PerformInviteRequest,
|
req *api.PerformInviteRequest,
|
||||||
) error {
|
cryptoIDs bool,
|
||||||
|
) (gomatrixserverlib.PDU, error) {
|
||||||
senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.InviteInput.RoomID, req.InviteInput.Inviter)
|
senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.InviteInput.RoomID, req.InviteInput.Inviter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
} else if senderID == nil {
|
} else if senderID == nil {
|
||||||
return fmt.Errorf("sender ID not found for %s in %s", req.InviteInput.Inviter, req.InviteInput.RoomID)
|
return nil, fmt.Errorf("sender ID not found for %s in %s", req.InviteInput.Inviter, req.InviteInput.RoomID)
|
||||||
}
|
}
|
||||||
info, err := r.DB.RoomInfo(ctx, req.InviteInput.RoomID.String())
|
info, err := r.DB.RoomInfo(ctx, req.InviteInput.RoomID.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
proto := gomatrixserverlib.ProtoEvent{
|
proto := gomatrixserverlib.ProtoEvent{
|
||||||
|
@ -152,20 +153,20 @@ func (r *Inviter) PerformInvite(
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = proto.SetContent(content); err != nil {
|
if err = proto.SetContent(content); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Inviter.Domain()) {
|
if !r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Inviter.Domain()) {
|
||||||
return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")}
|
return nil, api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")}
|
||||||
}
|
}
|
||||||
|
|
||||||
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Invitee.Domain())
|
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Invitee.Domain())
|
||||||
|
|
||||||
signingKey := req.InviteInput.PrivateKey
|
signingKey := req.InviteInput.PrivateKey
|
||||||
if info.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
|
if !cryptoIDs && info.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
|
||||||
signingKey, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, req.InviteInput.Inviter, req.InviteInput.RoomID)
|
signingKey, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, req.InviteInput.Inviter, req.InviteInput.RoomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -222,6 +223,10 @@ func (r *Inviter) PerformInvite(
|
||||||
}
|
}
|
||||||
return r.RSAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID)
|
return r.RSAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID)
|
||||||
},
|
},
|
||||||
|
CryptoIDs: cryptoIDs,
|
||||||
|
ClaimSenderID: func(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
|
||||||
|
return r.RSAPI.ClaimOneTimeSenderIDForUser(ctx, roomID, userID)
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI)
|
inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI)
|
||||||
|
@ -229,12 +234,14 @@ func (r *Inviter) PerformInvite(
|
||||||
switch e := err.(type) {
|
switch e := err.(type) {
|
||||||
case spec.MatrixError:
|
case spec.MatrixError:
|
||||||
if e.ErrCode == spec.ErrorForbidden {
|
if e.ErrCode == spec.ErrorForbidden {
|
||||||
return api.ErrNotAllowed{Err: fmt.Errorf("%s", e.Err)}
|
return nil, api.ErrNotAllowed{Err: fmt.Errorf("%s", e.Err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var response gomatrixserverlib.PDU
|
||||||
|
if !cryptoIDs {
|
||||||
// Send the invite event to the roomserver input stream. This will
|
// Send the invite event to the roomserver input stream. This will
|
||||||
// notify existing users in the room about the invite, update the
|
// notify existing users in the room about the invite, update the
|
||||||
// membership table and ensure that the event is ready and available
|
// membership table and ensure that the event is ready and available
|
||||||
|
@ -254,8 +261,11 @@ func (r *Inviter) PerformInvite(
|
||||||
r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes)
|
r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes)
|
||||||
if err := inputRes.Err(); err != nil {
|
if err := inputRes.Err(); err != nil {
|
||||||
util.GetLogger(ctx).WithField("event_id", inviteEvent.EventID()).Error("r.InputRoomEvents failed")
|
util.GetLogger(ctx).WithField("event_id", inviteEvent.EventID()).Error("r.InputRoomEvents failed")
|
||||||
return api.ErrNotAllowed{Err: err}
|
return nil, api.ErrNotAllowed{Err: err}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
response = inviteEvent
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -406,6 +406,380 @@ func (r *Joiner) performJoinRoomByID(
|
||||||
return req.RoomIDOrAlias, userDomain, nil
|
return req.RoomIDOrAlias, userDomain, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Joiner) PerformSendJoinCryptoIDs(
|
||||||
|
ctx context.Context,
|
||||||
|
req *rsAPI.PerformJoinRequestCryptoIDs,
|
||||||
|
) error {
|
||||||
|
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
|
||||||
|
"room_id": req.RoomID,
|
||||||
|
"user_id": req.UserID,
|
||||||
|
"servers": req.ServerNames,
|
||||||
|
})
|
||||||
|
logger.Info("performing send join")
|
||||||
|
res := fsAPI.PerformJoinResponse{}
|
||||||
|
r.FSAPI.PerformSendJoin(ctx, &fsAPI.PerformSendJoinRequestCryptoIDs{
|
||||||
|
RoomID: req.RoomID,
|
||||||
|
UserID: req.UserID,
|
||||||
|
ServerNames: req.ServerNames,
|
||||||
|
Unsigned: req.Unsigned,
|
||||||
|
Event: req.JoinEvent,
|
||||||
|
}, &res)
|
||||||
|
if res.LastError != nil {
|
||||||
|
return res.LastError
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Joiner) PerformSendInviteCryptoIDs(
|
||||||
|
ctx context.Context,
|
||||||
|
req *rsAPI.PerformInviteRequestCryptoIDs,
|
||||||
|
) error {
|
||||||
|
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
|
||||||
|
"room_id": req.RoomID,
|
||||||
|
"user_id": req.UserID,
|
||||||
|
})
|
||||||
|
logger.Info("performing send invite")
|
||||||
|
err := r.FSAPI.SendInviteCryptoIDs(ctx, req.InviteEvent, req.UserID, req.InviteEvent.Version())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PerformJoin handles joining matrix rooms, including over federation by talking to the federationapi.
|
||||||
|
func (r *Joiner) PerformJoinCryptoIDs(
|
||||||
|
ctx context.Context,
|
||||||
|
req *rsAPI.PerformJoinRequest,
|
||||||
|
) (joinEvent gomatrixserverlib.PDU, roomID string, version gomatrixserverlib.RoomVersion, serverName spec.ServerName, err error) {
|
||||||
|
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
|
||||||
|
"room_id": req.RoomIDOrAlias,
|
||||||
|
"user_id": req.UserID,
|
||||||
|
"servers": req.ServerNames,
|
||||||
|
})
|
||||||
|
logger.Info("User requested to room join")
|
||||||
|
join, roomID, version, serverName, err := r.makeJoinEvent(context.Background(), req)
|
||||||
|
if err != nil {
|
||||||
|
logger.WithError(err).Error("Failed to make join room event")
|
||||||
|
sentry.CaptureException(err)
|
||||||
|
return nil, "", "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return join, roomID, version, serverName, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Joiner) makeJoinEvent(
|
||||||
|
ctx context.Context,
|
||||||
|
req *rsAPI.PerformJoinRequest,
|
||||||
|
) (gomatrixserverlib.PDU, string, gomatrixserverlib.RoomVersion, spec.ServerName, error) {
|
||||||
|
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("supplied user ID %q in incorrect format", req.UserID)}
|
||||||
|
}
|
||||||
|
if !r.Cfg.Matrix.IsLocalServerName(domain) {
|
||||||
|
return nil, "", "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user %q does not belong to this homeserver", req.UserID)}
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(req.RoomIDOrAlias, "!") {
|
||||||
|
return r.performJoinRoomByIDCryptoIDs(ctx, req)
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(req.RoomIDOrAlias, "#") {
|
||||||
|
return r.performJoinRoomByAliasCryptoIDs(ctx, req)
|
||||||
|
}
|
||||||
|
return nil, "", "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID or alias %q is invalid", req.RoomIDOrAlias)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Joiner) performJoinRoomByAliasCryptoIDs(
|
||||||
|
ctx context.Context,
|
||||||
|
req *rsAPI.PerformJoinRequest,
|
||||||
|
) (gomatrixserverlib.PDU, string, gomatrixserverlib.RoomVersion, spec.ServerName, error) {
|
||||||
|
// Get the domain part of the room alias.
|
||||||
|
_, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", "", "", fmt.Errorf("alias %q is not in the correct format", req.RoomIDOrAlias)
|
||||||
|
}
|
||||||
|
req.ServerNames = append(req.ServerNames, domain)
|
||||||
|
|
||||||
|
// Check if this alias matches our own server configuration. If it
|
||||||
|
// doesn't then we'll need to try a federated join.
|
||||||
|
var roomID string
|
||||||
|
if !r.Cfg.Matrix.IsLocalServerName(domain) {
|
||||||
|
// The alias isn't owned by us, so we will need to try joining using
|
||||||
|
// a remote server.
|
||||||
|
dirReq := fsAPI.PerformDirectoryLookupRequest{
|
||||||
|
RoomAlias: req.RoomIDOrAlias, // the room alias to lookup
|
||||||
|
ServerName: domain, // the server to ask
|
||||||
|
}
|
||||||
|
dirRes := fsAPI.PerformDirectoryLookupResponse{}
|
||||||
|
err = r.FSAPI.PerformDirectoryLookup(ctx, &dirReq, &dirRes)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Errorf("error looking up alias %q", req.RoomIDOrAlias)
|
||||||
|
return nil, "", "", "", fmt.Errorf("looking up alias %q over federation failed: %w", req.RoomIDOrAlias, err)
|
||||||
|
}
|
||||||
|
roomID = dirRes.RoomID
|
||||||
|
req.ServerNames = append(req.ServerNames, dirRes.ServerNames...)
|
||||||
|
} else {
|
||||||
|
var getRoomReq = rsAPI.GetRoomIDForAliasRequest{
|
||||||
|
Alias: req.RoomIDOrAlias,
|
||||||
|
IncludeAppservices: true,
|
||||||
|
}
|
||||||
|
var getRoomRes = rsAPI.GetRoomIDForAliasResponse{}
|
||||||
|
// Otherwise, look up if we know this room alias locally.
|
||||||
|
err = r.RSAPI.GetRoomIDForAlias(ctx, &getRoomReq, &getRoomRes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", "", "", fmt.Errorf("lookup room alias %q failed: %w", req.RoomIDOrAlias, err)
|
||||||
|
}
|
||||||
|
roomID = getRoomRes.RoomID
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the room ID is empty then we failed to look up the alias.
|
||||||
|
if roomID == "" {
|
||||||
|
return nil, "", "", "", fmt.Errorf("alias %q not found", req.RoomIDOrAlias)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we do, then pluck out the room ID and continue the join.
|
||||||
|
req.RoomIDOrAlias = roomID
|
||||||
|
return r.performJoinRoomByIDCryptoIDs(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Break this function up a bit & move to GMSL
|
||||||
|
// nolint:gocyclo
|
||||||
|
func (r *Joiner) performJoinRoomByIDCryptoIDs(
|
||||||
|
ctx context.Context,
|
||||||
|
req *rsAPI.PerformJoinRequest,
|
||||||
|
) (gomatrixserverlib.PDU, string, gomatrixserverlib.RoomVersion, spec.ServerName, error) {
|
||||||
|
// The original client request ?server_name=... may include this HS so filter that out so we
|
||||||
|
// don't attempt to make_join with ourselves
|
||||||
|
for i := 0; i < len(req.ServerNames); i++ {
|
||||||
|
if r.Cfg.Matrix.IsLocalServerName(req.ServerNames[i]) {
|
||||||
|
// delete this entry
|
||||||
|
req.ServerNames = append(req.ServerNames[:i], req.ServerNames[i+1:]...)
|
||||||
|
i--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the domain part of the room ID.
|
||||||
|
roomID, err := spec.NewRoomID(req.RoomIDOrAlias)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", req.RoomIDOrAlias, err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the server name in the room ID isn't ours then it's a
|
||||||
|
// possible candidate for finding the room via federation. Add
|
||||||
|
// it to the list of servers to try.
|
||||||
|
if !r.Cfg.Matrix.IsLocalServerName(roomID.Domain()) {
|
||||||
|
req.ServerNames = append(req.ServerNames, roomID.Domain())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Force a federated join if we aren't in the room and we've been
|
||||||
|
// given some server names to try joining by.
|
||||||
|
inRoomReq := &rsAPI.QueryServerJoinedToRoomRequest{
|
||||||
|
RoomID: req.RoomIDOrAlias,
|
||||||
|
}
|
||||||
|
inRoomRes := &rsAPI.QueryServerJoinedToRoomResponse{}
|
||||||
|
if err = r.Queryer.QueryServerJoinedToRoom(ctx, inRoomReq, inRoomRes); err != nil {
|
||||||
|
return nil, "", "", "", fmt.Errorf("r.Queryer.QueryServerJoinedToRoom: %w", err)
|
||||||
|
}
|
||||||
|
serverInRoom := inRoomRes.IsInRoom
|
||||||
|
forceFederatedJoin := len(req.ServerNames) > 0 && !serverInRoom
|
||||||
|
|
||||||
|
userID, err := spec.NewUserID(req.UserID, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
//TODO: CryptoIDs - what is provided & calculated senderIDs don't match?
|
||||||
|
|
||||||
|
// Look up the room NID for the supplied room ID.
|
||||||
|
var senderID spec.SenderID
|
||||||
|
checkInvitePending := false
|
||||||
|
info, err := r.DB.RoomInfo(ctx, req.RoomIDOrAlias)
|
||||||
|
if err == nil && info != nil {
|
||||||
|
switch info.RoomVersion {
|
||||||
|
case gomatrixserverlib.RoomVersionCryptoIDs:
|
||||||
|
senderIDPtr, queryErr := r.Queryer.QuerySenderIDForUser(ctx, *roomID, *userID)
|
||||||
|
if queryErr == nil {
|
||||||
|
checkInvitePending = true
|
||||||
|
}
|
||||||
|
if senderIDPtr == nil {
|
||||||
|
senderID = req.SenderID
|
||||||
|
} else {
|
||||||
|
senderID = *senderIDPtr
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
checkInvitePending = true
|
||||||
|
senderID = spec.SenderID(userID.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Force a federated join if we're dealing with a pending invite
|
||||||
|
// and we aren't in the room.
|
||||||
|
if checkInvitePending {
|
||||||
|
isInvitePending, inviteSender, _, inviteEvent, inviteErr := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID)
|
||||||
|
if inviteErr == nil && !serverInRoom && isInvitePending {
|
||||||
|
inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, inviteSender)
|
||||||
|
if queryErr != nil {
|
||||||
|
return nil, "", "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we were invited by someone from another server then we can
|
||||||
|
// assume they are in the room so we can join via them.
|
||||||
|
if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) {
|
||||||
|
req.ServerNames = append(req.ServerNames, inviter.Domain())
|
||||||
|
forceFederatedJoin = true
|
||||||
|
memberEvent := gjson.Parse(string(inviteEvent.JSON()))
|
||||||
|
// only set unsigned if we've got a content.membership, which we _should_
|
||||||
|
if memberEvent.Get("content.membership").Exists() {
|
||||||
|
req.Unsigned = map[string]interface{}{
|
||||||
|
"prev_sender": memberEvent.Get("sender").Str,
|
||||||
|
"prev_content": map[string]interface{}{
|
||||||
|
"is_direct": memberEvent.Get("content.is_direct").Bool(),
|
||||||
|
"membership": memberEvent.Get("content.membership").Str,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If a guest is trying to join a room, check that the room has a m.room.guest_access event
|
||||||
|
if req.IsGuest {
|
||||||
|
var guestAccessEvent *types.HeaderedEvent
|
||||||
|
guestAccess := "forbidden"
|
||||||
|
guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, spec.MRoomGuestAccess, "")
|
||||||
|
if (err != nil && !errors.Is(err, sql.ErrNoRows)) || guestAccessEvent == nil {
|
||||||
|
logrus.WithError(err).Warn("unable to get m.room.guest_access event, defaulting to 'forbidden'")
|
||||||
|
}
|
||||||
|
if guestAccessEvent != nil {
|
||||||
|
guestAccess = gjson.GetBytes(guestAccessEvent.Content(), "guest_access").String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Servers MUST only allow guest users to join rooms if the m.room.guest_access state event
|
||||||
|
// is present on the room and has the guest_access value can_join.
|
||||||
|
if guestAccess != "can_join" {
|
||||||
|
return nil, "", "", "", rsAPI.ErrNotAllowed{Err: fmt.Errorf("guest access is forbidden")}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we should do a forced federated join then do that.
|
||||||
|
if forceFederatedJoin {
|
||||||
|
joinEvent, version, serverName, federatedJoinErr := r.performFederatedMakeJoinByIDCryptoIDs(ctx, req)
|
||||||
|
return joinEvent, req.RoomIDOrAlias, version, serverName, federatedJoinErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to construct an actual join event from the template.
|
||||||
|
// If this succeeds then it is a sign that the room already exists
|
||||||
|
// locally on the homeserver.
|
||||||
|
// TODO: Check what happens if the room exists on the server
|
||||||
|
// but everyone has since left. I suspect it does the wrong thing.
|
||||||
|
|
||||||
|
var buildRes rsAPI.QueryLatestEventsAndStateResponse
|
||||||
|
identity := r.Cfg.Matrix.SigningIdentity
|
||||||
|
|
||||||
|
// at this point we know we have an existing room
|
||||||
|
if inRoomRes.RoomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
|
||||||
|
mapping := &gomatrixserverlib.MXIDMapping{
|
||||||
|
UserRoomKey: senderID,
|
||||||
|
UserID: userID.String(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sign the mapping with the server identity
|
||||||
|
if err = mapping.Sign(identity.ServerName, identity.KeyID, identity.PrivateKey); err != nil {
|
||||||
|
return nil, "", "", "", err
|
||||||
|
}
|
||||||
|
req.Content["mxid_mapping"] = mapping
|
||||||
|
|
||||||
|
// sign the event with the pseudo ID key
|
||||||
|
identity = fclient.SigningIdentity{
|
||||||
|
ServerName: spec.ServerName(senderID),
|
||||||
|
KeyID: "ed25519:1",
|
||||||
|
PrivateKey: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
senderIDString := string(senderID)
|
||||||
|
|
||||||
|
// Prepare the template for the join event.
|
||||||
|
proto := gomatrixserverlib.ProtoEvent{
|
||||||
|
Type: spec.MRoomMember,
|
||||||
|
SenderID: senderIDString,
|
||||||
|
StateKey: &senderIDString,
|
||||||
|
RoomID: req.RoomIDOrAlias,
|
||||||
|
Redacts: "",
|
||||||
|
}
|
||||||
|
if err = proto.SetUnsigned(struct{}{}); err != nil {
|
||||||
|
return nil, "", "", "", fmt.Errorf("eb.SetUnsigned: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// It is possible for the request to include some "content" for the
|
||||||
|
// event. We'll always overwrite the "membership" key, but the rest,
|
||||||
|
// like "display_name" or "avatar_url", will be kept if supplied.
|
||||||
|
if req.Content == nil {
|
||||||
|
req.Content = map[string]interface{}{}
|
||||||
|
}
|
||||||
|
req.Content["membership"] = spec.Join
|
||||||
|
if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil {
|
||||||
|
return nil, "", "", "", aerr
|
||||||
|
} else if authorisedVia != "" {
|
||||||
|
req.Content["join_authorised_via_users_server"] = authorisedVia
|
||||||
|
}
|
||||||
|
if err = proto.SetContent(req.Content); err != nil {
|
||||||
|
return nil, "", "", "", fmt.Errorf("eb.SetContent: %w", err)
|
||||||
|
}
|
||||||
|
joinEvent, err := eventutil.QueryAndBuildEvent(ctx, &proto, &identity, time.Now(), r.RSAPI, &buildRes)
|
||||||
|
|
||||||
|
switch err.(type) {
|
||||||
|
case nil:
|
||||||
|
// Do nothing
|
||||||
|
|
||||||
|
case eventutil.ErrRoomNoExists:
|
||||||
|
// The room doesn't exist locally. If the room ID looks like it should
|
||||||
|
// be ours then this probably means that we've nuked our database at
|
||||||
|
// some point.
|
||||||
|
if r.Cfg.Matrix.IsLocalServerName(roomID.Domain()) {
|
||||||
|
// If there are no more server names to try then give up here.
|
||||||
|
// Otherwise we'll try a federated join as normal, since it's quite
|
||||||
|
// possible that the room still exists on other servers.
|
||||||
|
if len(req.ServerNames) == 0 {
|
||||||
|
return nil, "", "", "", eventutil.ErrRoomNoExists{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform a federated room join.
|
||||||
|
joinEvent, version, serverName, federatedJoinErr := r.performFederatedMakeJoinByIDCryptoIDs(ctx, req)
|
||||||
|
return joinEvent, req.RoomIDOrAlias, version, serverName, federatedJoinErr
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Something else went wrong.
|
||||||
|
return nil, "", "", "", fmt.Errorf("error joining local room: %q", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// By this point, if req.RoomIDOrAlias contained an alias, then
|
||||||
|
// it will have been overwritten with a room ID by performJoinRoomByAlias.
|
||||||
|
// We should now include this in the response so that the CS API can
|
||||||
|
// return the right room ID.
|
||||||
|
return joinEvent, req.RoomIDOrAlias, inRoomRes.RoomVersion, userID.Domain(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Joiner) performFederatedMakeJoinByIDCryptoIDs(
|
||||||
|
ctx context.Context,
|
||||||
|
req *rsAPI.PerformJoinRequest,
|
||||||
|
) (gomatrixserverlib.PDU, gomatrixserverlib.RoomVersion, spec.ServerName, error) {
|
||||||
|
// Try joining by all of the supplied server names.
|
||||||
|
fedReq := fsAPI.PerformJoinRequest{
|
||||||
|
RoomID: req.RoomIDOrAlias, // the room ID to try and join
|
||||||
|
UserID: req.UserID, // the user ID joining the room
|
||||||
|
ServerNames: req.ServerNames, // the server to try joining with
|
||||||
|
Content: req.Content, // the membership event content
|
||||||
|
Unsigned: req.Unsigned, // the unsigned event content, if any
|
||||||
|
}
|
||||||
|
joinEvent, version, serverName, err := r.FSAPI.PerformMakeJoin(ctx, &fedReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", "", err
|
||||||
|
}
|
||||||
|
return joinEvent, version, serverName, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Joiner) performFederatedJoinRoomByID(
|
func (r *Joiner) performFederatedJoinRoomByID(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *rsAPI.PerformJoinRequest,
|
req *rsAPI.PerformJoinRequest,
|
||||||
|
|
|
@ -24,6 +24,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
"github.com/matrix-org/gomatrix"
|
"github.com/matrix-org/gomatrix"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib/fclient"
|
||||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
@ -52,9 +53,10 @@ func (r *Leaver) PerformLeave(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *api.PerformLeaveRequest,
|
req *api.PerformLeaveRequest,
|
||||||
res *api.PerformLeaveResponse,
|
res *api.PerformLeaveResponse,
|
||||||
) ([]api.OutputEvent, error) {
|
cryptoIDs bool,
|
||||||
|
) ([]api.OutputEvent, gomatrixserverlib.PDU, error) {
|
||||||
if !r.Cfg.Matrix.IsLocalServerName(req.Leaver.Domain()) {
|
if !r.Cfg.Matrix.IsLocalServerName(req.Leaver.Domain()) {
|
||||||
return nil, fmt.Errorf("user %q does not belong to this homeserver", req.Leaver.String())
|
return nil, nil, fmt.Errorf("user %q does not belong to this homeserver", req.Leaver.String())
|
||||||
}
|
}
|
||||||
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
|
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
|
||||||
"room_id": req.RoomID,
|
"room_id": req.RoomID,
|
||||||
|
@ -62,15 +64,15 @@ func (r *Leaver) PerformLeave(
|
||||||
})
|
})
|
||||||
logger.Info("User requested to leave join")
|
logger.Info("User requested to leave join")
|
||||||
if strings.HasPrefix(req.RoomID, "!") {
|
if strings.HasPrefix(req.RoomID, "!") {
|
||||||
output, err := r.performLeaveRoomByID(context.Background(), req, res)
|
output, event, err := r.performLeaveRoomByID(context.Background(), req, res, cryptoIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).Error("Failed to leave room")
|
logger.WithError(err).Error("Failed to leave room")
|
||||||
} else {
|
} else {
|
||||||
logger.Info("User left room successfully")
|
logger.Info("User left room successfully")
|
||||||
}
|
}
|
||||||
return output, err
|
return output, event, err
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("room ID %q is invalid", req.RoomID)
|
return nil, nil, fmt.Errorf("room ID %q is invalid", req.RoomID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// nolint:gocyclo
|
// nolint:gocyclo
|
||||||
|
@ -78,14 +80,15 @@ func (r *Leaver) performLeaveRoomByID(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *api.PerformLeaveRequest,
|
req *api.PerformLeaveRequest,
|
||||||
res *api.PerformLeaveResponse, // nolint:unparam
|
res *api.PerformLeaveResponse, // nolint:unparam
|
||||||
) ([]api.OutputEvent, error) {
|
cryptoIDs bool,
|
||||||
|
) ([]api.OutputEvent, gomatrixserverlib.PDU, error) {
|
||||||
roomID, err := spec.NewRoomID(req.RoomID)
|
roomID, err := spec.NewRoomID(req.RoomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, *roomID, req.Leaver)
|
leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, *roomID, req.Leaver)
|
||||||
if err != nil || leaver == nil {
|
if err != nil || leaver == nil {
|
||||||
return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String())
|
return nil, nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there's an invite outstanding for the room then respond to
|
// If there's an invite outstanding for the room then respond to
|
||||||
|
@ -94,7 +97,7 @@ func (r *Leaver) performLeaveRoomByID(
|
||||||
if err == nil && isInvitePending {
|
if err == nil && isInvitePending {
|
||||||
sender, serr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, senderUser)
|
sender, serr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, senderUser)
|
||||||
if serr != nil {
|
if serr != nil {
|
||||||
return nil, fmt.Errorf("failed looking up userID for sender %q: %w", senderUser, serr)
|
return nil, nil, fmt.Errorf("failed looking up userID for sender %q: %w", senderUser, serr)
|
||||||
}
|
}
|
||||||
|
|
||||||
var domain spec.ServerName
|
var domain spec.ServerName
|
||||||
|
@ -107,7 +110,7 @@ func (r *Leaver) performLeaveRoomByID(
|
||||||
domain = sender.Domain()
|
domain = sender.Domain()
|
||||||
}
|
}
|
||||||
if !r.Cfg.Matrix.IsLocalServerName(domain) {
|
if !r.Cfg.Matrix.IsLocalServerName(domain) {
|
||||||
return r.performFederatedRejectInvite(ctx, req, res, domain, eventID, *leaver)
|
return r.performFederatedRejectInvite(ctx, req, res, domain, eventID, *leaver, cryptoIDs)
|
||||||
}
|
}
|
||||||
// check that this is not a "server notice room"
|
// check that this is not a "server notice room"
|
||||||
accData := &userapi.QueryAccountDataResponse{}
|
accData := &userapi.QueryAccountDataResponse{}
|
||||||
|
@ -116,7 +119,7 @@ func (r *Leaver) performLeaveRoomByID(
|
||||||
RoomID: req.RoomID,
|
RoomID: req.RoomID,
|
||||||
DataType: "m.tag",
|
DataType: "m.tag",
|
||||||
}, accData); err != nil {
|
}, accData); err != nil {
|
||||||
return nil, fmt.Errorf("unable to query account data: %w", err)
|
return nil, nil, fmt.Errorf("unable to query account data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if roomData, ok := accData.RoomAccountData[req.RoomID]; ok {
|
if roomData, ok := accData.RoomAccountData[req.RoomID]; ok {
|
||||||
|
@ -124,13 +127,13 @@ func (r *Leaver) performLeaveRoomByID(
|
||||||
if ok {
|
if ok {
|
||||||
tags := gomatrix.TagContent{}
|
tags := gomatrix.TagContent{}
|
||||||
if err = json.Unmarshal(tagData, &tags); err != nil {
|
if err = json.Unmarshal(tagData, &tags); err != nil {
|
||||||
return nil, fmt.Errorf("unable to unmarshal tag content")
|
return nil, nil, fmt.Errorf("unable to unmarshal tag content")
|
||||||
}
|
}
|
||||||
if _, ok = tags.Tags["m.server_notice"]; ok {
|
if _, ok = tags.Tags["m.server_notice"]; ok {
|
||||||
// mimic the returned values from Synapse
|
// mimic the returned values from Synapse
|
||||||
res.Message = "You cannot reject this invite"
|
res.Message = "You cannot reject this invite"
|
||||||
res.Code = 403
|
res.Code = 403
|
||||||
return nil, spec.LeaveServerNoticeError()
|
return nil, nil, spec.LeaveServerNoticeError()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -149,22 +152,22 @@ func (r *Leaver) performLeaveRoomByID(
|
||||||
}
|
}
|
||||||
latestRes := api.QueryLatestEventsAndStateResponse{}
|
latestRes := api.QueryLatestEventsAndStateResponse{}
|
||||||
if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r.RSAPI, &latestReq, &latestRes); err != nil {
|
if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r.RSAPI, &latestReq, &latestRes); err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
if !latestRes.RoomExists {
|
if !latestRes.RoomExists {
|
||||||
return nil, fmt.Errorf("room %q does not exist", req.RoomID)
|
return nil, nil, fmt.Errorf("room %q does not exist", req.RoomID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now let's see if the user is in the room.
|
// Now let's see if the user is in the room.
|
||||||
if len(latestRes.StateEvents) == 0 {
|
if len(latestRes.StateEvents) == 0 {
|
||||||
return nil, fmt.Errorf("user %q is not a member of room %q", req.Leaver.String(), req.RoomID)
|
return nil, nil, fmt.Errorf("user %q is not a member of room %q", req.Leaver.String(), req.RoomID)
|
||||||
}
|
}
|
||||||
membership, err := latestRes.StateEvents[0].Membership()
|
membership, err := latestRes.StateEvents[0].Membership()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting membership: %w", err)
|
return nil, nil, fmt.Errorf("error getting membership: %w", err)
|
||||||
}
|
}
|
||||||
if membership != spec.Join && membership != spec.Invite {
|
if membership != spec.Join && membership != spec.Invite {
|
||||||
return nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.Leaver.String(), membership)
|
return nil, nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.Leaver.String(), membership)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare the template for the leave event.
|
// Prepare the template for the leave event.
|
||||||
|
@ -177,10 +180,10 @@ func (r *Leaver) performLeaveRoomByID(
|
||||||
Redacts: "",
|
Redacts: "",
|
||||||
}
|
}
|
||||||
if err = proto.SetContent(map[string]interface{}{"membership": "leave"}); err != nil {
|
if err = proto.SetContent(map[string]interface{}{"membership": "leave"}); err != nil {
|
||||||
return nil, fmt.Errorf("eb.SetContent: %w", err)
|
return nil, nil, fmt.Errorf("eb.SetContent: %w", err)
|
||||||
}
|
}
|
||||||
if err = proto.SetUnsigned(struct{}{}); err != nil {
|
if err = proto.SetUnsigned(struct{}{}); err != nil {
|
||||||
return nil, fmt.Errorf("eb.SetUnsigned: %w", err)
|
return nil, nil, fmt.Errorf("eb.SetUnsigned: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// We know that the user is in the room at this point so let's build
|
// We know that the user is in the room at this point so let's build
|
||||||
|
@ -190,19 +193,29 @@ func (r *Leaver) performLeaveRoomByID(
|
||||||
|
|
||||||
validRoomID, err := spec.NewRoomID(req.RoomID)
|
validRoomID, err := spec.NewRoomID(req.RoomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var buildRes rsAPI.QueryLatestEventsAndStateResponse
|
var buildRes rsAPI.QueryLatestEventsAndStateResponse
|
||||||
identity, err := r.RSAPI.SigningIdentityFor(ctx, *validRoomID, req.Leaver)
|
var identity fclient.SigningIdentity
|
||||||
|
if !cryptoIDs {
|
||||||
|
identity, err = r.RSAPI.SigningIdentityFor(ctx, *validRoomID, req.Leaver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("SigningIdentityFor: %w", err)
|
return nil, nil, fmt.Errorf("SigningIdentityFor: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
identity = fclient.SigningIdentity{
|
||||||
|
ServerName: spec.ServerName(*leaver),
|
||||||
|
KeyID: "ed25519:1",
|
||||||
|
PrivateKey: nil,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
event, err := eventutil.QueryAndBuildEvent(ctx, &proto, &identity, time.Now(), r.RSAPI, &buildRes)
|
event, err := eventutil.QueryAndBuildEvent(ctx, &proto, &identity, time.Now(), r.RSAPI, &buildRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("eventutil.QueryAndBuildEvent: %w", err)
|
return nil, nil, fmt.Errorf("eventutil.QueryAndBuildEvent: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !cryptoIDs {
|
||||||
// Give our leave event to the roomserver input stream. The
|
// Give our leave event to the roomserver input stream. The
|
||||||
// roomserver will process the membership change and notify
|
// roomserver will process the membership change and notify
|
||||||
// downstream automatically.
|
// downstream automatically.
|
||||||
|
@ -219,10 +232,11 @@ func (r *Leaver) performLeaveRoomByID(
|
||||||
inputRes := api.InputRoomEventsResponse{}
|
inputRes := api.InputRoomEventsResponse{}
|
||||||
r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes)
|
r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes)
|
||||||
if err = inputRes.Err(); err != nil {
|
if err = inputRes.Err(); err != nil {
|
||||||
return nil, fmt.Errorf("r.InputRoomEvents: %w", err)
|
return nil, nil, fmt.Errorf("r.InputRoomEvents: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil
|
return nil, event, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Leaver) performFederatedRejectInvite(
|
func (r *Leaver) performFederatedRejectInvite(
|
||||||
|
@ -231,7 +245,8 @@ func (r *Leaver) performFederatedRejectInvite(
|
||||||
res *api.PerformLeaveResponse, // nolint:unparam
|
res *api.PerformLeaveResponse, // nolint:unparam
|
||||||
inviteDomain spec.ServerName, eventID string,
|
inviteDomain spec.ServerName, eventID string,
|
||||||
leaver spec.SenderID,
|
leaver spec.SenderID,
|
||||||
) ([]api.OutputEvent, error) {
|
cryptoIDs bool,
|
||||||
|
) ([]api.OutputEvent, gomatrixserverlib.PDU, error) {
|
||||||
// Ask the federation sender to perform a federated leave for us.
|
// Ask the federation sender to perform a federated leave for us.
|
||||||
leaveReq := fsAPI.PerformLeaveRequest{
|
leaveReq := fsAPI.PerformLeaveRequest{
|
||||||
RoomID: req.RoomID,
|
RoomID: req.RoomID,
|
||||||
|
@ -239,7 +254,7 @@ func (r *Leaver) performFederatedRejectInvite(
|
||||||
ServerNames: []spec.ServerName{inviteDomain},
|
ServerNames: []spec.ServerName{inviteDomain},
|
||||||
}
|
}
|
||||||
leaveRes := fsAPI.PerformLeaveResponse{}
|
leaveRes := fsAPI.PerformLeaveResponse{}
|
||||||
if err := r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil {
|
if err := r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes, cryptoIDs); err != nil {
|
||||||
// failures in PerformLeave should NEVER stop us from telling other components like the
|
// failures in PerformLeave should NEVER stop us from telling other components like the
|
||||||
// sync API that the invite was withdrawn. Otherwise we can end up with stuck invites.
|
// sync API that the invite was withdrawn. Otherwise we can end up with stuck invites.
|
||||||
util.GetLogger(ctx).WithError(err).Errorf("failed to PerformLeave, still retiring invite event")
|
util.GetLogger(ctx).WithError(err).Errorf("failed to PerformLeave, still retiring invite event")
|
||||||
|
@ -279,5 +294,5 @@ func (r *Leaver) performFederatedRejectInvite(
|
||||||
TargetSenderID: leaver,
|
TargetSenderID: leaver,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}, nil
|
}, nil, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1044,7 +1044,7 @@ func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID,
|
||||||
}
|
}
|
||||||
|
|
||||||
switch version {
|
switch version {
|
||||||
case gomatrixserverlib.RoomVersionPseudoIDs:
|
case gomatrixserverlib.RoomVersionPseudoIDs, gomatrixserverlib.RoomVersionCryptoIDs:
|
||||||
key, err := r.DB.SelectUserRoomPublicKey(ctx, userID, roomID)
|
key, err := r.DB.SelectUserRoomPublicKey(ctx, userID, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -16,7 +16,8 @@ type RoomServer struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RoomServer) Defaults(opts DefaultOpts) {
|
func (c *RoomServer) Defaults(opts DefaultOpts) {
|
||||||
c.DefaultRoomVersion = gomatrixserverlib.RoomVersionV10
|
//c.DefaultRoomVersion = gomatrixserverlib.RoomVersionV10
|
||||||
|
c.DefaultRoomVersion = gomatrixserverlib.RoomVersionCryptoIDs
|
||||||
if opts.Generate {
|
if opts.Generate {
|
||||||
if !opts.SingleDatabase {
|
if !opts.SingleDatabase {
|
||||||
c.Database.ConnectionString = "file:roomserver.db"
|
c.Database.ConnectionString = "file:roomserver.db"
|
||||||
|
|
|
@ -46,6 +46,16 @@ func DeviceOTKCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID, deviceI
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OTCryptoIDCounts adds one-time pseudoID counts to the /sync response
|
||||||
|
func OTCryptoIDCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID string, res *types.Response) error {
|
||||||
|
count, err := keyAPI.QueryOneTimeCryptoIDs(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res.OTCryptoIDsCount = count.KeyCount
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
|
// DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
|
||||||
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
|
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
|
||||||
// be already filled in with join/leave information.
|
// be already filled in with join/leave information.
|
||||||
|
|
|
@ -50,7 +50,9 @@ func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyC
|
||||||
}
|
}
|
||||||
func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error {
|
func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error {
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
|
func (a *mockKeyAPI) QueryOneTimeCryptoIDs(ctx context.Context, userID string) (userapi.OneTimeCryptoIDsCount, *userapi.KeyError) {
|
||||||
|
return userapi.OneTimeCryptoIDsCount{}, nil
|
||||||
}
|
}
|
||||||
func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.QueryDeviceMessagesRequest, res *userapi.QueryDeviceMessagesResponse) error {
|
func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.QueryDeviceMessagesRequest, res *userapi.QueryDeviceMessagesResponse) error {
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -38,7 +38,12 @@ func (p *DeviceListStreamProvider) IncrementalSync(
|
||||||
}
|
}
|
||||||
err = internal.DeviceOTKCounts(req.Context, p.userAPI, req.Device.UserID, req.Device.ID, req.Response)
|
err = internal.DeviceOTKCounts(req.Context, p.userAPI, req.Device.UserID, req.Device.ID, req.Response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
req.Log.WithError(err).Error("internal.DeviceListCatchup failed")
|
req.Log.WithError(err).Error("internal.DeviceOTKCounts failed")
|
||||||
|
return from
|
||||||
|
}
|
||||||
|
err = internal.OTCryptoIDCounts(req.Context, p.userAPI, req.Device.UserID, req.Response)
|
||||||
|
if err != nil {
|
||||||
|
req.Log.WithError(err).Error("internal.OTPseudoIDCounts failed")
|
||||||
return from
|
return from
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -280,6 +280,10 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
|
||||||
if err != nil && err != context.Canceled {
|
if err != nil && err != context.Canceled {
|
||||||
syncReq.Log.WithError(err).Warn("failed to get OTK counts")
|
syncReq.Log.WithError(err).Warn("failed to get OTK counts")
|
||||||
}
|
}
|
||||||
|
err = internal.OTCryptoIDCounts(syncReq.Context, rp.userAPI, syncReq.Device.UserID, syncReq.Response)
|
||||||
|
if err != nil && err != context.Canceled {
|
||||||
|
syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
|
|
|
@ -112,6 +112,10 @@ func (s *syncUserAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOn
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *syncUserAPI) QueryOneTimeCryptoIDs(ctx context.Context, userID string) (userapi.OneTimeCryptoIDsCount, *userapi.KeyError) {
|
||||||
|
return userapi.OneTimeCryptoIDsCount{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error {
|
func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -153,7 +153,7 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, userIDFor
|
||||||
// TODO: Set Signatures & Hashes fields
|
// TODO: Set Signatures & Hashes fields
|
||||||
}
|
}
|
||||||
|
|
||||||
if format != FormatSyncFederation && se.Version() == gomatrixserverlib.RoomVersionPseudoIDs {
|
if format != FormatSyncFederation && (se.Version() == gomatrixserverlib.RoomVersionPseudoIDs || se.Version() == gomatrixserverlib.RoomVersionCryptoIDs) {
|
||||||
err := updatePseudoIDs(&ce, se, userIDForSender, format)
|
err := updatePseudoIDs(&ce, se, userIDForSender, format)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -304,7 +304,7 @@ func GetUpdatedInviteRoomState(userIDForSender spec.UserIDForSender, inviteRoomS
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && eventFormat != FormatSyncFederation {
|
if (event.Version() == gomatrixserverlib.RoomVersionPseudoIDs || event.Version() == gomatrixserverlib.RoomVersionCryptoIDs) && eventFormat != FormatSyncFederation {
|
||||||
for i, ev := range inviteStateEvents {
|
for i, ev := range inviteStateEvents {
|
||||||
userID, userIDErr := userIDForSender(roomID, spec.SenderID(ev.SenderID))
|
userID, userIDErr := userIDForSender(roomID, spec.SenderID(ev.SenderID))
|
||||||
if userIDErr != nil {
|
if userIDErr != nil {
|
||||||
|
|
|
@ -24,6 +24,7 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
@ -365,6 +366,7 @@ type Response struct {
|
||||||
ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
|
ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
|
||||||
DeviceLists *DeviceLists `json:"device_lists,omitempty"`
|
DeviceLists *DeviceLists `json:"device_lists,omitempty"`
|
||||||
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
|
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
|
||||||
|
OTCryptoIDsCount map[string]int `json:"one_time_cryptoids_count,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r Response) MarshalJSON() ([]byte, error) {
|
func (r Response) MarshalJSON() ([]byte, error) {
|
||||||
|
@ -427,6 +429,7 @@ func NewResponse() *Response {
|
||||||
res.DeviceLists = &DeviceLists{}
|
res.DeviceLists = &DeviceLists{}
|
||||||
res.ToDevice = &ToDeviceResponse{}
|
res.ToDevice = &ToDeviceResponse{}
|
||||||
res.DeviceListsOTKCount = map[string]int{}
|
res.DeviceListsOTKCount = map[string]int{}
|
||||||
|
res.OTCryptoIDsCount = map[string]int{}
|
||||||
|
|
||||||
return &res
|
return &res
|
||||||
}
|
}
|
||||||
|
@ -530,6 +533,7 @@ type InviteResponse struct {
|
||||||
InviteState struct {
|
InviteState struct {
|
||||||
Events []json.RawMessage `json:"events"`
|
Events []json.RawMessage `json:"events"`
|
||||||
} `json:"invite_state"`
|
} `json:"invite_state"`
|
||||||
|
OneTimeCryptoID string `json:"one_time_cryptoid,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewInviteResponse creates an empty response with initialised arrays.
|
// NewInviteResponse creates an empty response with initialised arrays.
|
||||||
|
@ -537,11 +541,17 @@ func NewInviteResponse(ctx context.Context, rsAPI api.QuerySenderIDAPI, event *t
|
||||||
res := InviteResponse{}
|
res := InviteResponse{}
|
||||||
res.InviteState.Events = []json.RawMessage{}
|
res.InviteState.Events = []json.RawMessage{}
|
||||||
|
|
||||||
|
logrus.Infof("Room version: %s", event.Version())
|
||||||
|
if event.Version() == gomatrixserverlib.RoomVersionCryptoIDs {
|
||||||
|
logrus.Infof("Setting invite cryptoID to %s", *event.PDU.StateKey())
|
||||||
|
res.OneTimeCryptoID = *event.PDU.StateKey()
|
||||||
|
}
|
||||||
|
|
||||||
// First see if there's invite_room_state in the unsigned key of the invite.
|
// First see if there's invite_room_state in the unsigned key of the invite.
|
||||||
// If there is then unmarshal it into the response. This will contain the
|
// If there is then unmarshal it into the response. This will contain the
|
||||||
// partial room state such as join rules, room name etc.
|
// partial room state such as join rules, room name etc.
|
||||||
if inviteRoomState := gjson.GetBytes(event.Unsigned(), "invite_room_state"); inviteRoomState.Exists() {
|
if inviteRoomState := gjson.GetBytes(event.Unsigned(), "invite_room_state"); inviteRoomState.Exists() {
|
||||||
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && eventFormat != synctypes.FormatSyncFederation {
|
if (event.Version() == gomatrixserverlib.RoomVersionPseudoIDs || event.Version() == gomatrixserverlib.RoomVersionCryptoIDs) && eventFormat != synctypes.FormatSyncFederation {
|
||||||
updatedInvite, err := synctypes.GetUpdatedInviteRoomState(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
updatedInvite, err := synctypes.GetUpdatedInviteRoomState(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}, inviteRoomState, event.PDU, event.RoomID(), eventFormat)
|
}, inviteRoomState, event.PDU, event.RoomID(), eventFormat)
|
||||||
|
|
|
@ -51,6 +51,7 @@ type AppserviceUserAPI interface {
|
||||||
type RoomserverUserAPI interface {
|
type RoomserverUserAPI interface {
|
||||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||||
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
|
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
|
||||||
|
ClaimOneTimeCryptoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// api functions required by the media api
|
// api functions required by the media api
|
||||||
|
@ -669,6 +670,7 @@ type UploadDeviceKeysAPI interface {
|
||||||
type SyncKeyAPI interface {
|
type SyncKeyAPI interface {
|
||||||
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error
|
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error
|
||||||
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error
|
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error
|
||||||
|
QueryOneTimeCryptoIDs(ctx context.Context, userID string) (OneTimeCryptoIDsCount, *KeyError)
|
||||||
PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
|
PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -772,12 +774,25 @@ type OneTimeKeys struct {
|
||||||
KeyJSON map[string]json.RawMessage
|
KeyJSON map[string]json.RawMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OneTimeCryptoIDs struct {
|
||||||
|
// The user who owns this device
|
||||||
|
UserID string
|
||||||
|
// A map of algorithm:key_id => key JSON
|
||||||
|
KeyJSON map[string]json.RawMessage
|
||||||
|
}
|
||||||
|
|
||||||
// Split a key in KeyJSON into algorithm and key ID
|
// Split a key in KeyJSON into algorithm and key ID
|
||||||
func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
|
func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
|
||||||
segments := strings.Split(keyIDWithAlgo, ":")
|
segments := strings.Split(keyIDWithAlgo, ":")
|
||||||
return segments[0], segments[1]
|
return segments[0], segments[1]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Split a key in KeyJSON into algorithm and key ID
|
||||||
|
func (k *OneTimeCryptoIDs) Split(keyIDWithAlgo string) (algo string, keyID string) {
|
||||||
|
segments := strings.Split(keyIDWithAlgo, ":")
|
||||||
|
return segments[0], segments[1]
|
||||||
|
}
|
||||||
|
|
||||||
// OneTimeKeysCount represents the counts of one-time keys for a single device
|
// OneTimeKeysCount represents the counts of one-time keys for a single device
|
||||||
type OneTimeKeysCount struct {
|
type OneTimeKeysCount struct {
|
||||||
// The user who owns this device
|
// The user who owns this device
|
||||||
|
@ -792,12 +807,23 @@ type OneTimeKeysCount struct {
|
||||||
KeyCount map[string]int
|
KeyCount map[string]int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OneTimeCryptoIDsCount struct {
|
||||||
|
// The user who owns this device
|
||||||
|
UserID string
|
||||||
|
// algorithm to count e.g:
|
||||||
|
// {
|
||||||
|
// "pseudoid_curve25519": 10,
|
||||||
|
// }
|
||||||
|
KeyCount map[string]int
|
||||||
|
}
|
||||||
|
|
||||||
// PerformUploadKeysRequest is the request to PerformUploadKeys
|
// PerformUploadKeysRequest is the request to PerformUploadKeys
|
||||||
type PerformUploadKeysRequest struct {
|
type PerformUploadKeysRequest struct {
|
||||||
UserID string // Required - User performing the request
|
UserID string // Required - User performing the request
|
||||||
DeviceID string // Optional - Device performing the request, for fetching OTK count
|
DeviceID string // Optional - Device performing the request, for fetching OTK count
|
||||||
DeviceKeys []DeviceKeys
|
DeviceKeys []DeviceKeys
|
||||||
OneTimeKeys []OneTimeKeys
|
OneTimeKeys []OneTimeKeys
|
||||||
|
OneTimeCryptoIDs []OneTimeCryptoIDs
|
||||||
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
|
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
|
||||||
// the display name for their respective device, and NOT to modify the keys. The key
|
// the display name for their respective device, and NOT to modify the keys. The key
|
||||||
// itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
|
// itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
|
||||||
|
@ -812,6 +838,7 @@ type PerformUploadKeysResponse struct {
|
||||||
// A map of user_id -> device_id -> Error for tracking failures.
|
// A map of user_id -> device_id -> Error for tracking failures.
|
||||||
KeyErrors map[string]map[string]*KeyError
|
KeyErrors map[string]map[string]*KeyError
|
||||||
OneTimeKeyCounts []OneTimeKeysCount
|
OneTimeKeyCounts []OneTimeKeysCount
|
||||||
|
OneTimeCryptoIDCounts []OneTimeCryptoIDsCount
|
||||||
}
|
}
|
||||||
|
|
||||||
// PerformDeleteKeysRequest asks the keyserver to forget about certain
|
// PerformDeleteKeysRequest asks the keyserver to forget about certain
|
||||||
|
|
|
@ -647,7 +647,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
|
||||||
func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
|
func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
|
||||||
user := ""
|
user := ""
|
||||||
sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
||||||
if err == nil {
|
if err == nil && sender != nil {
|
||||||
user = sender.String()
|
user = sender.String()
|
||||||
}
|
}
|
||||||
if user == mem.UserID {
|
if user == mem.UserID {
|
||||||
|
|
|
@ -17,9 +17,11 @@ package internal
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -55,11 +57,21 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
|
||||||
if len(req.OneTimeKeys) > 0 {
|
if len(req.OneTimeKeys) > 0 {
|
||||||
a.uploadOneTimeKeys(ctx, req, res)
|
a.uploadOneTimeKeys(ctx, req, res)
|
||||||
}
|
}
|
||||||
|
if len(req.OneTimeCryptoIDs) > 0 {
|
||||||
|
a.uploadOneTimeCryptoIDs(ctx, req, res)
|
||||||
|
}
|
||||||
|
logrus.Infof("One time cryptoIDs count before: %v", res.OneTimeCryptoIDCounts)
|
||||||
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
|
res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
|
||||||
|
otpIDs, err := a.KeyDatabase.OneTimeCryptoIDsCount(ctx, req.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res.OneTimeCryptoIDCounts = []api.OneTimeCryptoIDsCount{*otpIDs}
|
||||||
|
logrus.Infof("One time cryptoIDs count after: %v", res.OneTimeCryptoIDCounts)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -181,6 +193,17 @@ func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOn
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) QueryOneTimeCryptoIDs(ctx context.Context, userID string) (api.OneTimeCryptoIDsCount, *api.KeyError) {
|
||||||
|
count, err := a.KeyDatabase.OneTimeCryptoIDsCount(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return api.OneTimeCryptoIDsCount{}, &api.KeyError{
|
||||||
|
Err: fmt.Sprintf("Failed to query OTID counts: %s", err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return *count, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
|
func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
|
||||||
msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false)
|
msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -773,6 +796,98 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) uploadOneTimeCryptoIDs(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||||
|
if req.UserID == "" {
|
||||||
|
res.Error = &api.KeyError{
|
||||||
|
Err: "user ID missing",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(req.OneTimeCryptoIDs) == 0 {
|
||||||
|
counts, err := a.KeyDatabase.OneTimeCryptoIDsCount(ctx, req.UserID)
|
||||||
|
if err != nil {
|
||||||
|
res.Error = &api.KeyError{
|
||||||
|
Err: fmt.Sprintf("a.KeyDatabase.OneTimeCryptoIDsCount: %s", err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if counts != nil {
|
||||||
|
logrus.Infof("Uploading one-time cryptoIDs: early result count: %v", *counts)
|
||||||
|
res.OneTimeCryptoIDCounts = append(res.OneTimeCryptoIDCounts, *counts)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, key := range req.OneTimeCryptoIDs {
|
||||||
|
// grab existing keys based on (user/algorithm/key ID)
|
||||||
|
keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
|
||||||
|
i := 0
|
||||||
|
for keyIDWithAlgo := range key.KeyJSON {
|
||||||
|
keyIDsWithAlgorithms[i] = keyIDWithAlgo
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
existingKeys, err := a.KeyDatabase.ExistingOneTimeCryptoIDs(ctx, req.UserID, keyIDsWithAlgorithms)
|
||||||
|
if err != nil {
|
||||||
|
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||||
|
Err: "failed to query existing one-time cryptoIDs: " + err.Error(),
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for keyIDWithAlgo := range existingKeys {
|
||||||
|
// if keys exist and the JSON doesn't match, error out as the key already exists
|
||||||
|
if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) {
|
||||||
|
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||||
|
Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time cryptoID already exists", req.UserID, req.DeviceID, keyIDWithAlgo),
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// store one-time keys
|
||||||
|
counts, err := a.KeyDatabase.StoreOneTimeCryptoIDs(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||||
|
Err: fmt.Sprintf("%s device %s : failed to store one-time cryptoIDs: %s", req.UserID, req.DeviceID, err.Error()),
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// collect counts
|
||||||
|
logrus.Infof("Uploading one-time cryptoIDs: result count: %v", *counts)
|
||||||
|
res.OneTimeCryptoIDCounts = append(res.OneTimeCryptoIDCounts, *counts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Ed25519Key struct {
|
||||||
|
Key spec.Base64Bytes `json:"key"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) ClaimOneTimeCryptoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
|
||||||
|
cryptoID, err := a.KeyDatabase.ClaimOneTimeCryptoID(ctx, userID, "ed25519")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.Infof("Claimed one time cryptoID: %s", cryptoID)
|
||||||
|
|
||||||
|
if cryptoID != nil {
|
||||||
|
for key, value := range cryptoID.KeyJSON {
|
||||||
|
keyParts := strings.Split(key, ":")
|
||||||
|
if keyParts[0] == "ed25519" {
|
||||||
|
var key_bytes Ed25519Key
|
||||||
|
err := json.Unmarshal(value, &key_bytes)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
length := len(key_bytes.Key)
|
||||||
|
if length != ed25519.PublicKeySize {
|
||||||
|
return "", fmt.Errorf("Invalid ed25519 public key, %d is the wrong size", length)
|
||||||
|
}
|
||||||
|
// TODO: cryptoIDs - store senderID for this user here?
|
||||||
|
return spec.SenderID(key_bytes.Key.Encode()), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("failed claiming a valid one time cryptoID for this user: %s", userID.String())
|
||||||
|
}
|
||||||
|
|
||||||
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
|
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
|
||||||
// if we only want to update the display names, we can skip the checks below
|
// if we only want to update the display names, we can skip the checks below
|
||||||
if onlyUpdateDisplayName {
|
if onlyUpdateDisplayName {
|
||||||
|
|
|
@ -175,6 +175,11 @@ type KeyDatabase interface {
|
||||||
// OneTimeKeysCount returns a count of all OTKs for this device.
|
// OneTimeKeysCount returns a count of all OTKs for this device.
|
||||||
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
|
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
|
||||||
|
|
||||||
|
ExistingOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||||
|
StoreOneTimeCryptoIDs(ctx context.Context, keys api.OneTimeCryptoIDs) (*api.OneTimeCryptoIDsCount, error)
|
||||||
|
OneTimeCryptoIDsCount(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error)
|
||||||
|
ClaimOneTimeCryptoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimeCryptoIDs, error)
|
||||||
|
|
||||||
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
||||||
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||||
|
|
||||||
|
|
191
userapi/storage/postgres/one_time_cryptoids_table.go
Normal file
191
userapi/storage/postgres/one_time_cryptoids_table.go
Normal file
|
@ -0,0 +1,191 @@
|
||||||
|
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
)
|
||||||
|
|
||||||
|
var oneTimeCryptoIDsSchema = `
|
||||||
|
-- Stores one-time cryptoIDs for users
|
||||||
|
CREATE TABLE IF NOT EXISTS keyserver_one_time_cryptoids (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
key_id TEXT NOT NULL,
|
||||||
|
algorithm TEXT NOT NULL,
|
||||||
|
ts_added_secs BIGINT NOT NULL,
|
||||||
|
key_json TEXT NOT NULL,
|
||||||
|
-- Clobber based on 3-uple of user/key/algorithm.
|
||||||
|
CONSTRAINT keyserver_one_time_cryptoids_unique UNIQUE (user_id, key_id, algorithm)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS keyserver_one_time_cryptoids_idx ON keyserver_one_time_cryptoids (user_id);
|
||||||
|
`
|
||||||
|
|
||||||
|
const upsertCryptoIDsSQL = "" +
|
||||||
|
"INSERT INTO keyserver_one_time_cryptoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
|
||||||
|
" VALUES ($1, $2, $3, $4, $5)" +
|
||||||
|
" ON CONFLICT ON CONSTRAINT keyserver_one_time_cryptoids_unique" +
|
||||||
|
" DO UPDATE SET key_json = $5"
|
||||||
|
|
||||||
|
const selectOneTimeCryptoIDsSQL = "" +
|
||||||
|
"SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_cryptoids WHERE user_id=$1 AND concat(algorithm, ':', key_id) = ANY($2);"
|
||||||
|
|
||||||
|
const selectCryptoIDsCountSQL = "" +
|
||||||
|
"SELECT algorithm, COUNT(key_id) FROM " +
|
||||||
|
" (SELECT algorithm, key_id FROM keyserver_one_time_cryptoids WHERE user_id = $1 LIMIT 100)" +
|
||||||
|
" x GROUP BY algorithm"
|
||||||
|
|
||||||
|
const deleteOneTimeCryptoIDSQL = "" +
|
||||||
|
"DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"
|
||||||
|
|
||||||
|
const selectCryptoIDByAlgorithmSQL = "" +
|
||||||
|
"SELECT key_id, key_json FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"
|
||||||
|
|
||||||
|
const deleteOneTimeCryptoIDsSQL = "" +
|
||||||
|
"DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1"
|
||||||
|
|
||||||
|
type oneTimeCryptoIDsStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
upsertCryptoIDsStmt *sql.Stmt
|
||||||
|
selectCryptoIDsStmt *sql.Stmt
|
||||||
|
selectCryptoIDsCountStmt *sql.Stmt
|
||||||
|
selectCryptoIDByAlgorithmStmt *sql.Stmt
|
||||||
|
deleteOneTimeCryptoIDStmt *sql.Stmt
|
||||||
|
deleteOneTimeCryptoIDsStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPostgresOneTimeCryptoIDsTable(db *sql.DB) (tables.OneTimeCryptoIDs, error) {
|
||||||
|
s := &oneTimeCryptoIDsStatements{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
_, err := db.Exec(oneTimeCryptoIDsSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, sqlutil.StatementList{
|
||||||
|
{&s.upsertCryptoIDsStmt, upsertCryptoIDsSQL},
|
||||||
|
{&s.selectCryptoIDsStmt, selectOneTimeCryptoIDsSQL},
|
||||||
|
{&s.selectCryptoIDsCountStmt, selectCryptoIDsCountSQL},
|
||||||
|
{&s.selectCryptoIDByAlgorithmStmt, selectCryptoIDByAlgorithmSQL},
|
||||||
|
{&s.deleteOneTimeCryptoIDStmt, deleteOneTimeCryptoIDSQL},
|
||||||
|
{&s.deleteOneTimeCryptoIDsStmt, deleteOneTimeCryptoIDsSQL},
|
||||||
|
}.Prepare(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oneTimeCryptoIDsStatements) SelectOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||||
|
rows, err := s.selectCryptoIDsStmt.QueryContext(ctx, userID, pq.Array(keyIDsWithAlgorithms))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsStmt: rows.close() failed")
|
||||||
|
|
||||||
|
result := make(map[string]json.RawMessage)
|
||||||
|
var (
|
||||||
|
algorithmWithID string
|
||||||
|
keyJSONStr string
|
||||||
|
)
|
||||||
|
for rows.Next() {
|
||||||
|
if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result[algorithmWithID] = json.RawMessage(keyJSONStr)
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oneTimeCryptoIDsStatements) CountOneTimeCryptoIDs(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) {
|
||||||
|
counts := &api.OneTimeCryptoIDsCount{
|
||||||
|
UserID: userID,
|
||||||
|
KeyCount: make(map[string]int),
|
||||||
|
}
|
||||||
|
rows, err := s.selectCryptoIDsCountStmt.QueryContext(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var algorithm string
|
||||||
|
var count int
|
||||||
|
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
counts.KeyCount[algorithm] = count
|
||||||
|
}
|
||||||
|
return counts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oneTimeCryptoIDsStatements) InsertOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimeCryptoIDs) (*api.OneTimeCryptoIDsCount, error) {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
counts := &api.OneTimeCryptoIDsCount{
|
||||||
|
UserID: keys.UserID,
|
||||||
|
KeyCount: make(map[string]int),
|
||||||
|
}
|
||||||
|
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||||
|
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||||
|
_, err := sqlutil.TxStmt(txn, s.upsertCryptoIDsStmt).ExecContext(
|
||||||
|
ctx, keys.UserID, keyID, algo, now, string(keyJSON),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rows, err := sqlutil.TxStmt(txn, s.selectCryptoIDsCountStmt).QueryContext(ctx, keys.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var algorithm string
|
||||||
|
var count int
|
||||||
|
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
counts.KeyCount[algorithm] = count
|
||||||
|
}
|
||||||
|
|
||||||
|
return counts, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oneTimeCryptoIDsStatements) SelectAndDeleteOneTimeCryptoID(
|
||||||
|
ctx context.Context, txn *sql.Tx, userID, algorithm string,
|
||||||
|
) (map[string]json.RawMessage, error) {
|
||||||
|
var keyID string
|
||||||
|
var keyJSON string
|
||||||
|
err := sqlutil.TxStmtContext(ctx, txn, s.selectCryptoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeCryptoIDStmt).ExecContext(ctx, userID, algorithm, keyID)
|
||||||
|
return map[string]json.RawMessage{
|
||||||
|
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oneTimeCryptoIDsStatements) DeleteOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||||
|
_, err := sqlutil.TxStmt(txn, s.deleteOneTimeCryptoIDsStmt).ExecContext(ctx, userID)
|
||||||
|
return err
|
||||||
|
}
|
|
@ -149,6 +149,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
otpid, err := NewPostgresOneTimeCryptoIDsTable(db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
dk, err := NewPostgresDeviceKeysTable(db)
|
dk, err := NewPostgresDeviceKeysTable(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -172,6 +176,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
||||||
|
|
||||||
return &shared.KeyDatabase{
|
return &shared.KeyDatabase{
|
||||||
OneTimeKeysTable: otk,
|
OneTimeKeysTable: otk,
|
||||||
|
OneTimeCryptoIDsTable: otpid,
|
||||||
DeviceKeysTable: dk,
|
DeviceKeysTable: dk,
|
||||||
KeyChangesTable: kc,
|
KeyChangesTable: kc,
|
||||||
StaleDeviceListsTable: sdl,
|
StaleDeviceListsTable: sdl,
|
||||||
|
|
|
@ -65,6 +65,7 @@ type Database struct {
|
||||||
|
|
||||||
type KeyDatabase struct {
|
type KeyDatabase struct {
|
||||||
OneTimeKeysTable tables.OneTimeKeys
|
OneTimeKeysTable tables.OneTimeKeys
|
||||||
|
OneTimeCryptoIDsTable tables.OneTimeCryptoIDs
|
||||||
DeviceKeysTable tables.DeviceKeys
|
DeviceKeysTable tables.DeviceKeys
|
||||||
KeyChangesTable tables.KeyChanges
|
KeyChangesTable tables.KeyChanges
|
||||||
StaleDeviceListsTable tables.StaleDeviceLists
|
StaleDeviceListsTable tables.StaleDeviceLists
|
||||||
|
@ -945,6 +946,40 @@ func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID str
|
||||||
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
|
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *KeyDatabase) ExistingOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||||
|
return d.OneTimeCryptoIDsTable.SelectOneTimeCryptoIDs(ctx, userID, keyIDsWithAlgorithms)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *KeyDatabase) StoreOneTimeCryptoIDs(ctx context.Context, keys api.OneTimeCryptoIDs) (counts *api.OneTimeCryptoIDsCount, err error) {
|
||||||
|
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
counts, err = d.OneTimeCryptoIDsTable.InsertOneTimeCryptoIDs(ctx, txn, keys)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *KeyDatabase) OneTimeCryptoIDsCount(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) {
|
||||||
|
return d.OneTimeCryptoIDsTable.CountOneTimeCryptoIDs(ctx, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *KeyDatabase) ClaimOneTimeCryptoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimeCryptoIDs, error) {
|
||||||
|
var result *api.OneTimeCryptoIDs
|
||||||
|
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
keyJSON, err := d.OneTimeCryptoIDsTable.SelectAndDeleteOneTimeCryptoID(ctx, txn, userID.String(), algorithm)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if keyJSON != nil {
|
||||||
|
result = &api.OneTimeCryptoIDs{
|
||||||
|
UserID: userID.String(),
|
||||||
|
KeyJSON: keyJSON,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
|
||||||
func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||||
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
||||||
}
|
}
|
||||||
|
|
208
userapi/storage/sqlite3/one_time_cryptoids_table.go
Normal file
208
userapi/storage/sqlite3/one_time_cryptoids_table.go
Normal file
|
@ -0,0 +1,208 @@
|
||||||
|
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
var oneTimeCryptoIDsSchema = `
|
||||||
|
-- Stores one-time cryptoIDs for users
|
||||||
|
CREATE TABLE IF NOT EXISTS keyserver_one_time_cryptoids (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
key_id TEXT NOT NULL,
|
||||||
|
algorithm TEXT NOT NULL,
|
||||||
|
ts_added_secs BIGINT NOT NULL,
|
||||||
|
key_json TEXT NOT NULL,
|
||||||
|
-- Clobber based on 3-uple of user/key/algorithm.
|
||||||
|
UNIQUE (user_id, key_id, algorithm)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS keyserver_one_time_cryptoids_idx ON keyserver_one_time_cryptoids (user_id);
|
||||||
|
`
|
||||||
|
|
||||||
|
const upsertCryptoIDsSQL = "" +
|
||||||
|
"INSERT INTO keyserver_one_time_cryptoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
|
||||||
|
" VALUES ($1, $2, $3, $4, $5)" +
|
||||||
|
" ON CONFLICT (user_id, key_id, algorithm)" +
|
||||||
|
" DO UPDATE SET key_json = $5"
|
||||||
|
|
||||||
|
const selectOneTimeCryptoIDsSQL = "" +
|
||||||
|
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_cryptoids WHERE user_id=$1"
|
||||||
|
|
||||||
|
const selectCryptoIDsCountSQL = "" +
|
||||||
|
"SELECT algorithm, COUNT(key_id) FROM " +
|
||||||
|
" (SELECT algorithm, key_id FROM keyserver_one_time_cryptoids WHERE user_id = $1 LIMIT 100)" +
|
||||||
|
" x GROUP BY algorithm"
|
||||||
|
|
||||||
|
const deleteOneTimeCryptoIDSQL = "" +
|
||||||
|
"DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"
|
||||||
|
|
||||||
|
const selectCryptoIDByAlgorithmSQL = "" +
|
||||||
|
"SELECT key_id, key_json FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"
|
||||||
|
|
||||||
|
const deleteOneTimeCryptoIDsSQL = "" +
|
||||||
|
"DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1"
|
||||||
|
|
||||||
|
type oneTimeCryptoIDsStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
upsertCryptoIDsStmt *sql.Stmt
|
||||||
|
selectCryptoIDsStmt *sql.Stmt
|
||||||
|
selectCryptoIDsCountStmt *sql.Stmt
|
||||||
|
selectCryptoIDByAlgorithmStmt *sql.Stmt
|
||||||
|
deleteOneTimeCryptoIDStmt *sql.Stmt
|
||||||
|
deleteOneTimeCryptoIDsStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSqliteOneTimeCryptoIDsTable(db *sql.DB) (tables.OneTimeCryptoIDs, error) {
|
||||||
|
s := &oneTimeCryptoIDsStatements{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
_, err := db.Exec(oneTimeCryptoIDsSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, sqlutil.StatementList{
|
||||||
|
{&s.upsertCryptoIDsStmt, upsertCryptoIDsSQL},
|
||||||
|
{&s.selectCryptoIDsStmt, selectOneTimeCryptoIDsSQL},
|
||||||
|
{&s.selectCryptoIDsCountStmt, selectCryptoIDsCountSQL},
|
||||||
|
{&s.selectCryptoIDByAlgorithmStmt, selectCryptoIDByAlgorithmSQL},
|
||||||
|
{&s.deleteOneTimeCryptoIDStmt, deleteOneTimeCryptoIDSQL},
|
||||||
|
{&s.deleteOneTimeCryptoIDsStmt, deleteOneTimeCryptoIDsSQL},
|
||||||
|
}.Prepare(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oneTimeCryptoIDsStatements) SelectOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||||
|
rows, err := s.selectCryptoIDsStmt.QueryContext(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsStmt: rows.close() failed")
|
||||||
|
|
||||||
|
wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
|
||||||
|
for _, ka := range keyIDsWithAlgorithms {
|
||||||
|
wantSet[ka] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make(map[string]json.RawMessage)
|
||||||
|
for rows.Next() {
|
||||||
|
var keyID string
|
||||||
|
var algorithm string
|
||||||
|
var keyJSONStr string
|
||||||
|
if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
keyIDWithAlgo := algorithm + ":" + keyID
|
||||||
|
if wantSet[keyIDWithAlgo] {
|
||||||
|
result[keyIDWithAlgo] = json.RawMessage(keyJSONStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oneTimeCryptoIDsStatements) CountOneTimeCryptoIDs(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) {
|
||||||
|
counts := &api.OneTimeCryptoIDsCount{
|
||||||
|
UserID: userID,
|
||||||
|
KeyCount: make(map[string]int),
|
||||||
|
}
|
||||||
|
rows, err := s.selectCryptoIDsCountStmt.QueryContext(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var algorithm string
|
||||||
|
var count int
|
||||||
|
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
counts.KeyCount[algorithm] = count
|
||||||
|
}
|
||||||
|
return counts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oneTimeCryptoIDsStatements) InsertOneTimeCryptoIDs(
|
||||||
|
ctx context.Context, txn *sql.Tx, keys api.OneTimeCryptoIDs,
|
||||||
|
) (*api.OneTimeCryptoIDsCount, error) {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
counts := &api.OneTimeCryptoIDsCount{
|
||||||
|
UserID: keys.UserID,
|
||||||
|
KeyCount: make(map[string]int),
|
||||||
|
}
|
||||||
|
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||||
|
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||||
|
_, err := sqlutil.TxStmt(txn, s.upsertCryptoIDsStmt).ExecContext(
|
||||||
|
ctx, keys.UserID, keyID, algo, now, string(keyJSON),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rows, err := sqlutil.TxStmt(txn, s.selectCryptoIDsCountStmt).QueryContext(ctx, keys.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var algorithm string
|
||||||
|
var count int
|
||||||
|
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
counts.KeyCount[algorithm] = count
|
||||||
|
}
|
||||||
|
|
||||||
|
return counts, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oneTimeCryptoIDsStatements) SelectAndDeleteOneTimeCryptoID(
|
||||||
|
ctx context.Context, txn *sql.Tx, userID, algorithm string,
|
||||||
|
) (map[string]json.RawMessage, error) {
|
||||||
|
var keyID string
|
||||||
|
var keyJSON string
|
||||||
|
err := sqlutil.TxStmtContext(ctx, txn, s.selectCryptoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
logrus.Warnf("No rows found for one time cryptoIDs")
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeCryptoIDStmt).ExecContext(ctx, userID, algorithm, keyID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if keyJSON == "" {
|
||||||
|
logrus.Warnf("Empty key JSON for one time cryptoIDs")
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return map[string]json.RawMessage{
|
||||||
|
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oneTimeCryptoIDsStatements) DeleteOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||||
|
_, err := sqlutil.TxStmt(txn, s.deleteOneTimeCryptoIDsStmt).ExecContext(ctx, userID)
|
||||||
|
return err
|
||||||
|
}
|
|
@ -146,6 +146,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
otpid, err := NewSqliteOneTimeCryptoIDsTable(db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
dk, err := NewSqliteDeviceKeysTable(db)
|
dk, err := NewSqliteDeviceKeysTable(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -169,6 +173,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
||||||
|
|
||||||
return &shared.KeyDatabase{
|
return &shared.KeyDatabase{
|
||||||
OneTimeKeysTable: otk,
|
OneTimeKeysTable: otk,
|
||||||
|
OneTimeCryptoIDsTable: otpid,
|
||||||
DeviceKeysTable: dk,
|
DeviceKeysTable: dk,
|
||||||
KeyChangesTable: kc,
|
KeyChangesTable: kc,
|
||||||
StaleDeviceListsTable: sdl,
|
StaleDeviceListsTable: sdl,
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package storage_test
|
package storage_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -758,3 +759,35 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOneTimeCryptoIDs(t *testing.T) {
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, clean := mustCreateKeyDatabase(t, dbType)
|
||||||
|
defer clean()
|
||||||
|
userID := "@alice:localhost"
|
||||||
|
otk := api.OneTimeCryptoIDs{
|
||||||
|
UserID: userID,
|
||||||
|
KeyJSON: map[string]json.RawMessage{"pseudoid_curve25519:KEY1": []byte(`{"key":"v1"}`)},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a one time pseudoID to the DB
|
||||||
|
_, err := db.StoreOneTimeCryptoIDs(ctx, otk)
|
||||||
|
MustNotError(t, err)
|
||||||
|
|
||||||
|
// Check the count of one time pseudoIDs is correct
|
||||||
|
count, err := db.OneTimeCryptoIDsCount(ctx, userID)
|
||||||
|
MustNotError(t, err)
|
||||||
|
if count.KeyCount["pseudoid_curve25519"] != 1 {
|
||||||
|
t.Fatalf("Expected 1 pseudoID, got %d", count.KeyCount["pseudoid_curve25519"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the actual pseudoid contents are correct
|
||||||
|
keysJSON, err := db.ExistingOneTimeCryptoIDs(ctx, userID, []string{"pseudoid_curve25519:KEY1"})
|
||||||
|
MustNotError(t, err)
|
||||||
|
keyJSON, err := keysJSON["pseudoid_curve25519:KEY1"].MarshalJSON()
|
||||||
|
MustNotError(t, err)
|
||||||
|
if !bytes.Equal(keyJSON, []byte(`{"key":"v1"}`)) {
|
||||||
|
t.Fatalf("Existing pseudoIDs do not match expected. Got %v", keysJSON["pseudoid_curve25519:KEY1"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -168,6 +168,14 @@ type OneTimeKeys interface {
|
||||||
DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
|
DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OneTimeCryptoIDs interface {
|
||||||
|
SelectOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||||
|
CountOneTimeCryptoIDs(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error)
|
||||||
|
InsertOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimeCryptoIDs) (*api.OneTimeCryptoIDsCount, error)
|
||||||
|
SelectAndDeleteOneTimeCryptoID(ctx context.Context, txn *sql.Tx, userID, algorithm string) (map[string]json.RawMessage, error)
|
||||||
|
DeleteOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, userID string) error
|
||||||
|
}
|
||||||
|
|
||||||
type DeviceKeys interface {
|
type DeviceKeys interface {
|
||||||
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||||
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
|
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
|
||||||
|
|
Loading…
Reference in a new issue