Merge branch 'master' into single-event-retrieval-671

This commit is contained in:
Cnly 2019-08-06 23:27:06 +08:00
commit 64685aa416
38 changed files with 634 additions and 213 deletions

View file

@ -1,49 +0,0 @@
steps:
- command:
# https://github.com/golangci/golangci-lint#memory-usage-of-golangci-lint
- "GOGC=20 ./scripts/find-lint.sh"
label: "\U0001F9F9 Lint / :go: 1.12"
agents:
# Use a larger instance as linting takes a looot of memory
queue: "medium"
plugins:
- docker#v3.0.1:
image: "golang:1.12"
- wait
- command:
- "go build ./cmd/..."
label: "\U0001F528 Build / :go: 1.11"
plugins:
- docker#v3.0.1:
image: "golang:1.11"
retry:
automatic:
- exit_status: 128
limit: 3
- command:
- "go build ./cmd/..."
label: "\U0001F528 Build / :go: 1.12"
plugins:
- docker#v3.0.1:
image: "golang:1.12"
retry:
automatic:
- exit_status: 128
limit: 3
- command:
- "go test ./..."
label: "\U0001F9EA Unit tests / :go: 1.11"
plugins:
- docker#v3.0.1:
image: "golang:1.11"
- command:
- "go test ./..."
label: "\U0001F9EA Unit tests / :go: 1.12"
plugins:
- docker#v3.0.1:
image: "golang:1.12"

View file

@ -17,6 +17,7 @@ package accounts
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -71,25 +72,44 @@ func (s *filterStatements) prepare(db *sql.DB) (err error) {
func (s *filterStatements) selectFilter( func (s *filterStatements) selectFilter(
ctx context.Context, localpart string, filterID string, ctx context.Context, localpart string, filterID string,
) (filter []byte, err error) { ) (*gomatrixserverlib.Filter, error) {
err = s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filter) // Retrieve filter from database (stored as canonical JSON)
return var filterData []byte
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
if err != nil {
return nil, err
}
// Unmarshal JSON into Filter struct
var filter gomatrixserverlib.Filter
if err = json.Unmarshal(filterData, &filter); err != nil {
return nil, err
}
return &filter, nil
} }
func (s *filterStatements) insertFilter( func (s *filterStatements) insertFilter(
ctx context.Context, filter []byte, localpart string, ctx context.Context, filter *gomatrixserverlib.Filter, localpart string,
) (filterID string, err error) { ) (filterID string, err error) {
var existingFilterID string var existingFilterID string
// This can result in a race condition when two clients try to insert the // Serialise json
// same filter and localpart at the same time, however this is not a filterJSON, err := json.Marshal(filter)
// problem as both calls will result in the same filterID if err != nil {
filterJSON, err := gomatrixserverlib.CanonicalJSON(filter) return "", err
}
// Remove whitespaces and sort JSON data
// needed to prevent from inserting the same filter multiple times
filterJSON, err = gomatrixserverlib.CanonicalJSON(filterJSON)
if err != nil { if err != nil {
return "", err return "", err
} }
// Check if filter already exists in the database // Check if filter already exists in the database using its localpart and content
//
// This can result in a race condition when two clients try to insert the
// same filter and localpart at the same time, however this is not a
// problem as both calls will result in the same filterID
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
localpart, filterJSON).Scan(&existingFilterID) localpart, filterJSON).Scan(&existingFilterID)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {

View file

@ -230,7 +230,7 @@ func (d *Database) newMembership(
} }
// Only "join" membership events can be considered as new memberships // Only "join" membership events can be considered as new memberships
if membership == "join" { if membership == gomatrixserverlib.Join {
if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil { if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil {
return err return err
} }
@ -344,11 +344,11 @@ func (d *Database) GetThreePIDsForLocalpart(
} }
// GetFilter looks up the filter associated with a given local user and filter ID. // GetFilter looks up the filter associated with a given local user and filter ID.
// Returns a filter represented as a byte slice. Otherwise returns an error if // Returns a filter structure. Otherwise returns an error if no such filter exists
// no such filter exists or if there was an error talking to the database. // or if there was an error talking to the database.
func (d *Database) GetFilter( func (d *Database) GetFilter(
ctx context.Context, localpart string, filterID string, ctx context.Context, localpart string, filterID string,
) ([]byte, error) { ) (*gomatrixserverlib.Filter, error) {
return d.filter.selectFilter(ctx, localpart, filterID) return d.filter.selectFilter(ctx, localpart, filterID)
} }
@ -356,7 +356,7 @@ func (d *Database) GetFilter(
// Returns the filterID as a string. Otherwise returns an error if something // Returns the filterID as a string. Otherwise returns an error if something
// goes wrong. // goes wrong.
func (d *Database) PutFilter( func (d *Database) PutFilter(
ctx context.Context, localpart string, filter []byte, ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
) (string, error) { ) (string, error) {
return d.filter.insertFilter(ctx, filter, localpart) return d.filter.insertFilter(ctx, filter, localpart)
} }

View file

@ -169,6 +169,8 @@ func (s *devicesStatements) selectDeviceByToken(
return &dev, err return &dev, err
} }
// selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID
func (s *devicesStatements) selectDeviceByID( func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string, ctx context.Context, localpart, deviceID string,
) (*authtypes.Device, error) { ) (*authtypes.Device, error) {

View file

@ -84,7 +84,7 @@ func (d *Database) CreateDevice(
if deviceID != nil { if deviceID != nil {
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error var err error
// Revoke existing token for this device // Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err return err
} }

View file

@ -15,6 +15,7 @@
package routing package routing
import ( import (
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
@ -54,10 +55,6 @@ const (
presetPublicChat = "public_chat" presetPublicChat = "public_chat"
) )
const (
joinRulePublic = "public"
joinRuleInvite = "invite"
)
const ( const (
historyVisibilityShared = "shared" historyVisibilityShared = "shared"
// TODO: These should be implemented once history visibility is implemented // TODO: These should be implemented once history visibility is implemented
@ -97,6 +94,27 @@ func (r createRoomRequest) Validate() *util.JSONResponse {
} }
} }
// Validate creation_content fields defined in the spec by marshalling the
// creation_content map into bytes and then unmarshalling the bytes into
// common.CreateContent.
creationContentBytes, err := json.Marshal(r.CreationContent)
if err != nil {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("malformed creation_content"),
}
}
var CreationContent common.CreateContent
err = json.Unmarshal(creationContentBytes, &CreationContent)
if err != nil {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("malformed creation_content"),
}
}
return nil return nil
} }
@ -154,7 +172,17 @@ func createRoom(
JSON: jsonerror.InvalidArgumentValue(err.Error()), JSON: jsonerror.InvalidArgumentValue(err.Error()),
} }
} }
// TODO: visibility/presets/raw initial state/creation content
// Clobber keys: creator, room_version
if r.CreationContent == nil {
r.CreationContent = make(map[string]interface{}, 2)
}
r.CreationContent["creator"] = userID
r.CreationContent["room_version"] = "1" // TODO: We set this to 1 before we support Room versioning
// TODO: visibility/presets/raw initial state
// TODO: Create room alias association // TODO: Create room alias association
// Make sure this doesn't fall into an application service's namespace though! // Make sure this doesn't fall into an application service's namespace though!
@ -169,7 +197,7 @@ func createRoom(
} }
membershipContent := common.MemberContent{ membershipContent := common.MemberContent{
Membership: "join", Membership: gomatrixserverlib.Join,
DisplayName: profile.DisplayName, DisplayName: profile.DisplayName,
AvatarURL: profile.AvatarURL, AvatarURL: profile.AvatarURL,
} }
@ -177,19 +205,19 @@ func createRoom(
var joinRules, historyVisibility string var joinRules, historyVisibility string
switch r.Preset { switch r.Preset {
case presetPrivateChat: case presetPrivateChat:
joinRules = joinRuleInvite joinRules = gomatrixserverlib.Invite
historyVisibility = historyVisibilityShared historyVisibility = historyVisibilityShared
case presetTrustedPrivateChat: case presetTrustedPrivateChat:
joinRules = joinRuleInvite joinRules = gomatrixserverlib.Invite
historyVisibility = historyVisibilityShared historyVisibility = historyVisibilityShared
// TODO If trusted_private_chat, all invitees are given the same power level as the room creator. // TODO If trusted_private_chat, all invitees are given the same power level as the room creator.
case presetPublicChat: case presetPublicChat:
joinRules = joinRulePublic joinRules = gomatrixserverlib.Public
historyVisibility = historyVisibilityShared historyVisibility = historyVisibilityShared
default: default:
// Default room rules, r.Preset was previously checked for valid values so // Default room rules, r.Preset was previously checked for valid values so
// only a request with no preset should end up here. // only a request with no preset should end up here.
joinRules = joinRuleInvite joinRules = gomatrixserverlib.Invite
historyVisibility = historyVisibilityShared historyVisibility = historyVisibilityShared
} }
@ -214,7 +242,7 @@ func createRoom(
// harder to reason about, hence sticking to a strict static ordering. // harder to reason about, hence sticking to a strict static ordering.
// TODO: Synapse has txn/token ID on each event. Do we need to do this here? // TODO: Synapse has txn/token ID on each event. Do we need to do this here?
eventsToMake := []fledglingEvent{ eventsToMake := []fledglingEvent{
{"m.room.create", "", common.CreateContent{Creator: userID}}, {"m.room.create", "", r.CreationContent},
{"m.room.member", userID, membershipContent}, {"m.room.member", userID, membershipContent},
{"m.room.power_levels", "", common.InitialPowerLevelsContent(userID)}, {"m.room.power_levels", "", common.InitialPowerLevelsContent(userID)},
// TODO: m.room.canonical_alias // TODO: m.room.canonical_alias

View file

@ -17,13 +17,10 @@ package routing
import ( import (
"net/http" "net/http"
"encoding/json"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -43,7 +40,7 @@ func GetFilter(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
res, err := accountDB.GetFilter(req.Context(), localpart, filterID) filter, err := accountDB.GetFilter(req.Context(), localpart, filterID)
if err != nil { if err != nil {
//TODO better error handling. This error message is *probably* right, //TODO better error handling. This error message is *probably* right,
// but if there are obscure db errors, this will also be returned, // but if there are obscure db errors, this will also be returned,
@ -53,11 +50,6 @@ func GetFilter(
JSON: jsonerror.NotFound("No such filter"), JSON: jsonerror.NotFound("No such filter"),
} }
} }
filter := gomatrix.Filter{}
err = json.Unmarshal(res, &filter)
if err != nil {
return httputil.LogThenError(req, err)
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
@ -85,21 +77,21 @@ func PutFilter(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
var filter gomatrix.Filter var filter gomatrixserverlib.Filter
if reqErr := httputil.UnmarshalJSONRequest(req, &filter); reqErr != nil { if reqErr := httputil.UnmarshalJSONRequest(req, &filter); reqErr != nil {
return *reqErr return *reqErr
} }
filterArray, err := json.Marshal(filter) // Validate generates a user-friendly error
if err != nil { if err = filter.Validate(); err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("Filter is malformed"), JSON: jsonerror.BadJSON("Invalid filter: " + err.Error()),
} }
} }
filterID, err := accountDB.PutFilter(req.Context(), localpart, filterArray) filterID, err := accountDB.PutFilter(req.Context(), localpart, &filter)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }

View file

@ -70,7 +70,7 @@ func JoinRoomByIDOrAlias(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
content["membership"] = "join" content["membership"] = gomatrixserverlib.Join
content["displayname"] = profile.DisplayName content["displayname"] = profile.DisplayName
content["avatar_url"] = profile.AvatarURL content["avatar_url"] = profile.AvatarURL

View file

@ -18,7 +18,6 @@ import (
"net/http" "net/http"
"context" "context"
"database/sql"
"github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -44,8 +43,10 @@ type flow struct {
type passwordRequest struct { type passwordRequest struct {
User string `json:"user"` User string `json:"user"`
Password string `json:"password"` Password string `json:"password"`
// Both DeviceID and InitialDisplayName can be omitted, or empty strings ("")
// Thus a pointer is needed to differentiate between the two
InitialDisplayName *string `json:"initial_device_display_name"` InitialDisplayName *string `json:"initial_device_display_name"`
DeviceID string `json:"device_id"` DeviceID *string `json:"device_id"`
} }
type loginResponse struct { type loginResponse struct {
@ -110,7 +111,7 @@ func Login(
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }
dev, err := getDevice(req.Context(), r, deviceDB, acc, localpart, token) dev, err := getDevice(req.Context(), r, deviceDB, acc, token)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
@ -134,20 +135,16 @@ func Login(
} }
} }
// check if device exists else create one // getDevice returns a new or existing device
func getDevice( func getDevice(
ctx context.Context, ctx context.Context,
r passwordRequest, r passwordRequest,
deviceDB *devices.Database, deviceDB *devices.Database,
acc *authtypes.Account, acc *authtypes.Account,
localpart, token string, token string,
) (dev *authtypes.Device, err error) { ) (dev *authtypes.Device, err error) {
dev, err = deviceDB.GetDeviceByID(ctx, localpart, r.DeviceID)
if err == sql.ErrNoRows {
// device doesn't exist, create one
dev, err = deviceDB.CreateDevice( dev, err = deviceDB.CreateDevice(
ctx, acc.Localpart, nil, token, r.InitialDisplayName, ctx, acc.Localpart, r.DeviceID, token, r.InitialDisplayName,
) )
}
return return
} }

View file

@ -102,7 +102,7 @@ func SendMembership(
var returnData interface{} = struct{}{} var returnData interface{} = struct{}{}
// The join membership requires the room id to be sent in the response // The join membership requires the room id to be sent in the response
if membership == "join" { if membership == gomatrixserverlib.Join {
returnData = struct { returnData = struct {
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
}{roomID} }{roomID}
@ -141,7 +141,7 @@ func buildMembershipEvent(
// "unban" or "kick" isn't a valid membership value, change it to "leave" // "unban" or "kick" isn't a valid membership value, change it to "leave"
if membership == "unban" || membership == "kick" { if membership == "unban" || membership == "kick" {
membership = "leave" membership = gomatrixserverlib.Leave
} }
content := common.MemberContent{ content := common.MemberContent{
@ -192,7 +192,7 @@ func loadProfile(
func getMembershipStateKey( func getMembershipStateKey(
body threepid.MembershipRequest, device *authtypes.Device, membership string, body threepid.MembershipRequest, device *authtypes.Device, membership string,
) (stateKey string, reason string, err error) { ) (stateKey string, reason string, err error) {
if membership == "ban" || membership == "unban" || membership == "kick" || membership == "invite" { if membership == gomatrixserverlib.Ban || membership == "unban" || membership == "kick" || membership == gomatrixserverlib.Invite {
// If we're in this case, the state key is contained in the request body, // If we're in this case, the state key is contained in the request body,
// possibly along with a reason (for "kick" and "ban") so we need to parse // possibly along with a reason (for "kick" and "ban") so we need to parse
// it // it

View file

@ -264,7 +264,7 @@ func buildMembershipEvents(
} }
content := common.MemberContent{ content := common.MemberContent{
Membership: "join", Membership: gomatrixserverlib.Join,
} }
content.DisplayName = newProfile.DisplayName content.DisplayName = newProfile.DisplayName

View file

@ -121,7 +121,10 @@ type registerRequest struct {
// user-interactive auth params // user-interactive auth params
Auth authDict `json:"auth"` Auth authDict `json:"auth"`
// Both DeviceID and InitialDisplayName can be omitted, or empty strings ("")
// Thus a pointer is needed to differentiate between the two
InitialDisplayName *string `json:"initial_device_display_name"` InitialDisplayName *string `json:"initial_device_display_name"`
DeviceID *string `json:"device_id"`
// Prevent this user from logging in // Prevent this user from logging in
InhibitLogin common.WeakBoolean `json:"inhibit_login"` InhibitLogin common.WeakBoolean `json:"inhibit_login"`
@ -626,7 +629,7 @@ func handleApplicationServiceRegistration(
// application service registration is entirely separate. // application service registration is entirely separate.
return completeRegistration( return completeRegistration(
req.Context(), accountDB, deviceDB, r.Username, "", appserviceID, req.Context(), accountDB, deviceDB, r.Username, "", appserviceID,
r.InhibitLogin, r.InitialDisplayName, r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
) )
} }
@ -646,7 +649,7 @@ func checkAndCompleteFlow(
// This flow was completed, registration can continue // This flow was completed, registration can continue
return completeRegistration( return completeRegistration(
req.Context(), accountDB, deviceDB, r.Username, r.Password, "", req.Context(), accountDB, deviceDB, r.Username, r.Password, "",
r.InhibitLogin, r.InitialDisplayName, r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
) )
} }
@ -697,10 +700,10 @@ func LegacyRegister(
return util.MessageResponse(http.StatusForbidden, "HMAC incorrect") return util.MessageResponse(http.StatusForbidden, "HMAC incorrect")
} }
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil) return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil, nil)
case authtypes.LoginTypeDummy: case authtypes.LoginTypeDummy:
// there is nothing to do // there is nothing to do
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil) return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil, nil)
default: default:
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotImplemented, Code: http.StatusNotImplemented,
@ -738,13 +741,19 @@ func parseAndValidateLegacyLogin(req *http.Request, r *legacyRegisterRequest) *u
return nil return nil
} }
// completeRegistration runs some rudimentary checks against the submitted
// input, then if successful creates an account and a newly associated device
// We pass in each individual part of the request here instead of just passing a
// registerRequest, as this function serves requests encoded as both
// registerRequests and legacyRegisterRequests, which share some attributes but
// not all
func completeRegistration( func completeRegistration(
ctx context.Context, ctx context.Context,
accountDB *accounts.Database, accountDB *accounts.Database,
deviceDB *devices.Database, deviceDB *devices.Database,
username, password, appserviceID string, username, password, appserviceID string,
inhibitLogin common.WeakBoolean, inhibitLogin common.WeakBoolean,
displayName *string, displayName, deviceID *string,
) util.JSONResponse { ) util.JSONResponse {
if username == "" { if username == "" {
return util.JSONResponse{ return util.JSONResponse{
@ -773,6 +782,9 @@ func completeRegistration(
} }
} }
// Increment prometheus counter for created users
amtRegUsers.Inc()
// Check whether inhibit_login option is set. If so, don't create an access // Check whether inhibit_login option is set. If so, don't create an access
// token or a device for this user // token or a device for this user
if inhibitLogin { if inhibitLogin {
@ -793,8 +805,7 @@ func completeRegistration(
} }
} }
// TODO: Use the device ID in the request. dev, err := deviceDB.CreateDevice(ctx, username, deviceID, token, displayName)
dev, err := deviceDB.CreateDevice(ctx, username, nil, token, displayName)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
@ -802,9 +813,6 @@ func completeRegistration(
} }
} }
// Increment prometheus counter for created users
amtRegUsers.Inc()
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: registerResponse{ JSON: registerResponse{

View file

@ -0,0 +1,234 @@
// Copyright 2019 Sumukha PK
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"encoding/json"
"net/http"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// newTag creates and returns a new gomatrix.TagContent
func newTag() gomatrix.TagContent {
return gomatrix.TagContent{
Tags: make(map[string]gomatrix.TagProperties),
}
}
// GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags
func GetTags(
req *http.Request,
accountDB *accounts.Database,
device *authtypes.Device,
userID string,
roomID string,
syncProducer *producers.SyncAPIProducer,
) util.JSONResponse {
if device.UserID != userID {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Cannot retrieve another user's tags"),
}
}
_, data, err := obtainSavedTags(req, userID, roomID, accountDB)
if err != nil {
return httputil.LogThenError(req, err)
}
if len(data) == 0 {
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: data[0].Content,
}
}
// PutTag implements PUT /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags/{tag}
// Put functionality works by getting existing data from the DB (if any), adding
// the tag to the "map" and saving the new "map" to the DB
func PutTag(
req *http.Request,
accountDB *accounts.Database,
device *authtypes.Device,
userID string,
roomID string,
tag string,
syncProducer *producers.SyncAPIProducer,
) util.JSONResponse {
if device.UserID != userID {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Cannot modify another user's tags"),
}
}
var properties gomatrix.TagProperties
if reqErr := httputil.UnmarshalJSONRequest(req, &properties); reqErr != nil {
return *reqErr
}
localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB)
if err != nil {
return httputil.LogThenError(req, err)
}
var tagContent gomatrix.TagContent
if len(data) > 0 {
if err = json.Unmarshal(data[0].Content, &tagContent); err != nil {
return httputil.LogThenError(req, err)
}
} else {
tagContent = newTag()
}
tagContent.Tags[tag] = properties
if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil {
return httputil.LogThenError(req, err)
}
// Send data to syncProducer in order to inform clients of changes
// Run in a goroutine in order to prevent blocking the tag request response
go func() {
if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
}
}()
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
// DeleteTag implements DELETE /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags/{tag}
// Delete functionality works by obtaining the saved tags, removing the intended tag from
// the "map" and then saving the new "map" in the DB
func DeleteTag(
req *http.Request,
accountDB *accounts.Database,
device *authtypes.Device,
userID string,
roomID string,
tag string,
syncProducer *producers.SyncAPIProducer,
) util.JSONResponse {
if device.UserID != userID {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Cannot modify another user's tags"),
}
}
localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB)
if err != nil {
return httputil.LogThenError(req, err)
}
// If there are no tags in the database, exit
if len(data) == 0 {
// Spec only defines 200 responses for this endpoint so we don't return anything else.
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
var tagContent gomatrix.TagContent
err = json.Unmarshal(data[0].Content, &tagContent)
if err != nil {
return httputil.LogThenError(req, err)
}
// Check whether the tag to be deleted exists
if _, ok := tagContent.Tags[tag]; ok {
delete(tagContent.Tags, tag)
} else {
// Spec only defines 200 responses for this endpoint so we don't return anything else.
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil {
return httputil.LogThenError(req, err)
}
// Send data to syncProducer in order to inform clients of changes
// Run in a goroutine in order to prevent blocking the tag request response
go func() {
if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
}
}()
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
// obtainSavedTags gets all tags scoped to a userID and roomID
// from the database
func obtainSavedTags(
req *http.Request,
userID string,
roomID string,
accountDB *accounts.Database,
) (string, []gomatrixserverlib.ClientEvent, error) {
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return "", nil, err
}
data, err := accountDB.GetAccountDataByType(
req.Context(), localpart, roomID, "m.tag",
)
return localpart, data, err
}
// saveTagData saves the provided tag data into the database
func saveTagData(
req *http.Request,
localpart string,
roomID string,
accountDB *accounts.Database,
Tag gomatrix.TagContent,
) error {
newTagData, err := json.Marshal(Tag)
if err != nil {
return err
}
return accountDB.SaveAccountData(req.Context(), localpart, roomID, "m.tag", string(newTagData))
}

View file

@ -93,7 +93,7 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/join/{roomIDOrAlias}", r0mux.Handle("/join/{roomIDOrAlias}",
common.MakeAuthAPI("join", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { common.MakeAuthAPI(gomatrixserverlib.Join, authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req)) vars, err := common.URLDecodeMapValues(mux.Vars(req))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -492,4 +492,34 @@ func Setup(
}} }}
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/rooms/{roomId}/tags",
common.MakeAuthAPI("get_tags", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return GetTags(req, accountDB, device, vars["userId"], vars["roomId"], syncProducer)
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
common.MakeAuthAPI("put_tag", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return PutTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}",
common.MakeAuthAPI("delete_tag", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return DeleteTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
}),
).Methods(http.MethodDelete, http.MethodOptions)
} }

View file

@ -91,7 +91,7 @@ func CheckAndProcessInvite(
producer *producers.RoomserverProducer, membership string, roomID string, producer *producers.RoomserverProducer, membership string, roomID string,
evTime time.Time, evTime time.Time,
) (inviteStoredOnIDServer bool, err error) { ) (inviteStoredOnIDServer bool, err error) {
if membership != "invite" || (body.Address == "" && body.IDServer == "" && body.Medium == "") { if membership != gomatrixserverlib.Invite || (body.Address == "" && body.IDServer == "" && body.Medium == "") {
// If none of the 3PID-specific fields are supplied, it's a standard invite // If none of the 3PID-specific fields are supplied, it's a standard invite
// so return nil for it to be processed as such // so return nil for it to be processed as such
return return

View file

@ -86,7 +86,7 @@ func main() {
// Build a m.room.member event. // Build a m.room.member event.
b.Type = "m.room.member" b.Type = "m.room.member"
b.StateKey = userID b.StateKey = userID
b.SetContent(map[string]string{"membership": "join"}) // nolint: errcheck b.SetContent(map[string]string{"membership": gomatrixserverlib.Join}) // nolint: errcheck
b.AuthEvents = []gomatrixserverlib.EventReference{create} b.AuthEvents = []gomatrixserverlib.EventReference{create}
member := buildAndOutput() member := buildAndOutput()

View file

@ -54,12 +54,14 @@ database:
server_key: "postgresql:///server_keys" server_key: "postgresql:///server_keys"
sync_api: "postgresql:///syn_api" sync_api: "postgresql:///syn_api"
room_server: "postgresql:///room_server" room_server: "postgresql:///room_server"
appservice: "postgresql:///appservice"
listen: listen:
room_server: "localhost:7770" room_server: "localhost:7770"
client_api: "localhost:7771" client_api: "localhost:7771"
federation_api: "localhost:7772" federation_api: "localhost:7772"
sync_api: "localhost:7773" sync_api: "localhost:7773"
media_api: "localhost:7774" media_api: "localhost:7774"
appservice_api: "localhost:7777"
typing_server: "localhost:7778" typing_server: "localhost:7778"
logging: logging:
- type: "file" - type: "file"

View file

@ -18,6 +18,14 @@ package common
type CreateContent struct { type CreateContent struct {
Creator string `json:"creator"` Creator string `json:"creator"`
Federate *bool `json:"m.federate,omitempty"` Federate *bool `json:"m.federate,omitempty"`
RoomVersion string `json:"room_version,omitempty"`
Predecessor PreviousRoom `json:"predecessor,omitempty"`
}
// PreviousRoom is the "Previous Room" structure defined at https://matrix.org/docs/spec/client_server/r0.5.0#m-room-create
type PreviousRoom struct {
RoomID string `json:"room_id"`
EventID string `json:"event_id"`
} }
// MemberContent is the event content for http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-member // MemberContent is the event content for http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-member

View file

@ -15,9 +15,12 @@
package common package common
import ( import (
"fmt"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
"runtime"
"strings"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dugong" "github.com/matrix-org/dugong"
@ -54,15 +57,35 @@ func (h *logLevelHook) Levels() []logrus.Level {
return levels return levels
} }
// callerPrettyfier is a function that given a runtime.Frame object, will
// extract the calling function's name and file, and return them in a nicely
// formatted way
func callerPrettyfier(f *runtime.Frame) (string, string) {
// Retrieve just the function name
s := strings.Split(f.Function, ".")
funcname := s[len(s)-1]
// Append a newline + tab to it to move the actual log content to its own line
funcname += "\n\t"
// Surround the filepath in brackets and append line number so IDEs can quickly
// navigate
filename := fmt.Sprintf(" [%s:%d]", f.File, f.Line)
return funcname, filename
}
// SetupStdLogging configures the logging format to standard output. Typically, it is called when the config is not yet loaded. // SetupStdLogging configures the logging format to standard output. Typically, it is called when the config is not yet loaded.
func SetupStdLogging() { func SetupStdLogging() {
logrus.SetReportCaller(true)
logrus.SetFormatter(&utcFormatter{ logrus.SetFormatter(&utcFormatter{
&logrus.TextFormatter{ &logrus.TextFormatter{
TimestampFormat: "2006-01-02T15:04:05.000000000Z07:00", TimestampFormat: "2006-01-02T15:04:05.000000000Z07:00",
FullTimestamp: true, FullTimestamp: true,
DisableColors: false, DisableColors: false,
DisableTimestamp: false, DisableTimestamp: false,
DisableSorting: false, QuoteEmptyFields: true,
CallerPrettyfier: callerPrettyfier,
}, },
}) })
} }
@ -71,8 +94,8 @@ func SetupStdLogging() {
// If something fails here it means that the logging was improperly configured, // If something fails here it means that the logging was improperly configured,
// so we just exit with the error // so we just exit with the error
func SetupHookLogging(hooks []config.LogrusHook, componentName string) { func SetupHookLogging(hooks []config.LogrusHook, componentName string) {
logrus.SetReportCaller(true)
for _, hook := range hooks { for _, hook := range hooks {
// Check we received a proper logging level // Check we received a proper logging level
level, err := logrus.ParseLevel(hook.Level) level, err := logrus.ParseLevel(hook.Level)
if err != nil { if err != nil {
@ -126,6 +149,7 @@ func setupFileHook(hook config.LogrusHook, level logrus.Level, componentName str
DisableColors: true, DisableColors: true,
DisableTimestamp: false, DisableTimestamp: false,
DisableSorting: false, DisableSorting: false,
QuoteEmptyFields: true,
}, },
}, },
&dugong.DailyRotationSchedule{GZip: true}, &dugong.DailyRotationSchedule{GZip: true},

View file

@ -58,7 +58,7 @@ func MakeJoin(
Type: "m.room.member", Type: "m.room.member",
StateKey: &userID, StateKey: &userID,
} }
err = builder.SetContent(map[string]interface{}{"membership": "join"}) err = builder.SetContent(map[string]interface{}{"membership": gomatrixserverlib.Join})
if err != nil { if err != nil {
return httputil.LogThenError(httpReq, err) return httputil.LogThenError(httpReq, err)
} }

View file

@ -56,7 +56,7 @@ func MakeLeave(
Type: "m.room.member", Type: "m.room.member",
StateKey: &userID, StateKey: &userID,
} }
err = builder.SetContent(map[string]interface{}{"membership": "leave"}) err = builder.SetContent(map[string]interface{}{"membership": gomatrixserverlib.Leave})
if err != nil { if err != nil {
return httputil.LogThenError(httpReq, err) return httputil.LogThenError(httpReq, err)
} }
@ -153,7 +153,7 @@ func SendLeave(
mem, err := event.Membership() mem, err := event.Membership()
if err != nil { if err != nil {
return httputil.LogThenError(httpReq, err) return httputil.LogThenError(httpReq, err)
} else if mem != "leave" { } else if mem != gomatrixserverlib.Leave {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The membership in the event content must be set to leave"), JSON: jsonerror.BadJSON("The membership in the event content must be set to leave"),

View file

@ -202,7 +202,7 @@ func createInviteFrom3PIDInvite(
content := common.MemberContent{ content := common.MemberContent{
AvatarURL: profile.AvatarURL, AvatarURL: profile.AvatarURL,
DisplayName: profile.DisplayName, DisplayName: profile.DisplayName,
Membership: "invite", Membership: gomatrixserverlib.Invite,
ThirdPartyInvite: &common.TPInvite{ ThirdPartyInvite: &common.TPInvite{
Signed: inv.Signed, Signed: inv.Signed,
}, },

View file

@ -233,7 +233,7 @@ func joinedHostsFromEvents(evs []gomatrixserverlib.Event) ([]types.JoinedHost, e
if err != nil { if err != nil {
return nil, err return nil, err
} }
if membership != "join" { if membership != gomatrixserverlib.Join {
continue continue
} }
_, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) _, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())

10
go.mod
View file

@ -20,10 +20,11 @@ require (
github.com/jaegertracing/jaeger-client-go v0.0.0-20170921145708-3ad49a1d839b github.com/jaegertracing/jaeger-client-go v0.0.0-20170921145708-3ad49a1d839b
github.com/jaegertracing/jaeger-lib v0.0.0-20170920222118-21a3da6d66fe github.com/jaegertracing/jaeger-lib v0.0.0-20170920222118-21a3da6d66fe
github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6 github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
github.com/lib/pq v0.0.0-20170918175043-23da1db4f16d github.com/lib/pq v0.0.0-20170918175043-23da1db4f16d
github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5 github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26
github.com/matrix-org/gomatrixserverlib v0.0.0-20190619132215-178ed5e3b8e2 github.com/matrix-org/gomatrixserverlib v0.0.0-20190805173246-3a2199d5ecd6
github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0 github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0
github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5 github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5
github.com/matttproud/golang_protobuf_extensions v1.0.1 github.com/matttproud/golang_protobuf_extensions v1.0.1
@ -40,8 +41,9 @@ require (
github.com/prometheus/common v0.0.0-20170108231212-dd2f054febf4 github.com/prometheus/common v0.0.0-20170108231212-dd2f054febf4
github.com/prometheus/procfs v0.0.0-20170128160123-1878d9fbb537 github.com/prometheus/procfs v0.0.0-20170128160123-1878d9fbb537
github.com/rcrowley/go-metrics v0.0.0-20161128210544-1f30fe9094a5 github.com/rcrowley/go-metrics v0.0.0-20161128210544-1f30fe9094a5
github.com/sirupsen/logrus v1.3.0 github.com/sirupsen/logrus v1.4.2
github.com/stretchr/testify v1.2.2 github.com/stretchr/objx v0.2.0 // indirect
github.com/stretchr/testify v1.3.0
github.com/tidwall/gjson v1.1.5 github.com/tidwall/gjson v1.1.5
github.com/tidwall/match v1.0.1 github.com/tidwall/match v1.0.1
github.com/tidwall/sjson v1.0.3 github.com/tidwall/sjson v1.0.3
@ -54,7 +56,7 @@ require (
go.uber.org/zap v1.7.1 go.uber.org/zap v1.7.1
golang.org/x/crypto v0.0.0-20190131182504-b8fe1690c613 golang.org/x/crypto v0.0.0-20190131182504-b8fe1690c613
golang.org/x/net v0.0.0-20190301231341-16b79f2e4e95 golang.org/x/net v0.0.0-20190301231341-16b79f2e4e95
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33 golang.org/x/sys v0.0.0-20190712062909-fae7ac547cb7
gopkg.in/Shopify/sarama.v1 v1.11.0 gopkg.in/Shopify/sarama.v1 v1.11.0
gopkg.in/airbrake/gobrake.v2 v2.0.9 gopkg.in/airbrake/gobrake.v2 v2.0.9
gopkg.in/alecthomas/kingpin.v3-unstable v3.0.0-20170727041045-23bcc3c4eae3 gopkg.in/alecthomas/kingpin.v3-unstable v3.0.0-20170727041045-23bcc3c4eae3

13
go.sum
View file

@ -36,6 +36,7 @@ github.com/jaegertracing/jaeger-lib v0.0.0-20170920222118-21a3da6d66fe/go.mod h1
github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6 h1:KAZ1BW2TCmT6PRihDPpocIy1QTtsAsrx6TneU/4+CMg= github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6 h1:KAZ1BW2TCmT6PRihDPpocIy1QTtsAsrx6TneU/4+CMg=
github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6/go.mod h1:+ZoRqAPRLkC4NPOvfYeR5KNOrY6TD+/sAC3HXPZgDYg= github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6/go.mod h1:+ZoRqAPRLkC4NPOvfYeR5KNOrY6TD+/sAC3HXPZgDYg=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
@ -53,6 +54,10 @@ github.com/matrix-org/gomatrixserverlib v0.0.0-20181109104322-1c2cbc0872f0 h1:3U
github.com/matrix-org/gomatrixserverlib v0.0.0-20181109104322-1c2cbc0872f0/go.mod h1:YHyhIQUmuXyKtoVfDUMk/DyU93Taamlu6nPZkij/JtA= github.com/matrix-org/gomatrixserverlib v0.0.0-20181109104322-1c2cbc0872f0/go.mod h1:YHyhIQUmuXyKtoVfDUMk/DyU93Taamlu6nPZkij/JtA=
github.com/matrix-org/gomatrixserverlib v0.0.0-20190619132215-178ed5e3b8e2 h1:pYajAEdi3sowj4iSunqctchhcMNW3rDjeeH0T4uDkMY= github.com/matrix-org/gomatrixserverlib v0.0.0-20190619132215-178ed5e3b8e2 h1:pYajAEdi3sowj4iSunqctchhcMNW3rDjeeH0T4uDkMY=
github.com/matrix-org/gomatrixserverlib v0.0.0-20190619132215-178ed5e3b8e2/go.mod h1:sf0RcKOdiwJeTti7A313xsaejNUGYDq02MQZ4JD4w/E= github.com/matrix-org/gomatrixserverlib v0.0.0-20190619132215-178ed5e3b8e2/go.mod h1:sf0RcKOdiwJeTti7A313xsaejNUGYDq02MQZ4JD4w/E=
github.com/matrix-org/gomatrixserverlib v0.0.0-20190724145009-a6df10ef35d6 h1:B8n1H5Wb1B5jwLzTylBpY0kJCMRqrofT7PmOw4aJFJA=
github.com/matrix-org/gomatrixserverlib v0.0.0-20190724145009-a6df10ef35d6/go.mod h1:sf0RcKOdiwJeTti7A313xsaejNUGYDq02MQZ4JD4w/E=
github.com/matrix-org/gomatrixserverlib v0.0.0-20190805173246-3a2199d5ecd6 h1:xr69Hk6QM3RIN6JSvx3RpDowBGpHpDDqhqXCeySwYow=
github.com/matrix-org/gomatrixserverlib v0.0.0-20190805173246-3a2199d5ecd6/go.mod h1:sf0RcKOdiwJeTti7A313xsaejNUGYDq02MQZ4JD4w/E=
github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0 h1:p7WTwG+aXM86+yVrYAiCMW3ZHSmotVvuRbjtt3jC+4A= github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0 h1:p7WTwG+aXM86+yVrYAiCMW3ZHSmotVvuRbjtt3jC+4A=
github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0/go.mod h1:cXoYQIENbdWIQHt1SyCo6Bl3C3raHwJ0wgVrXHSqf+A= github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0/go.mod h1:cXoYQIENbdWIQHt1SyCo6Bl3C3raHwJ0wgVrXHSqf+A=
github.com/matrix-org/util v0.0.0-20171013132526-8b1c8ab81986 h1:TiWl4hLvezAhRPM8tPcPDFTysZ7k4T/1J4GPp/iqlZo= github.com/matrix-org/util v0.0.0-20171013132526-8b1c8ab81986 h1:TiWl4hLvezAhRPM8tPcPDFTysZ7k4T/1J4GPp/iqlZo=
@ -90,9 +95,14 @@ github.com/sirupsen/logrus v0.0.0-20170822132746-89742aefa4b2 h1:+8J/sCAVv2Y9Ct1
github.com/sirupsen/logrus v0.0.0-20170822132746-89742aefa4b2/go.mod h1:pMByvHTf9Beacp5x1UXfOR9xyW/9antXMhjMPG0dEzc= github.com/sirupsen/logrus v0.0.0-20170822132746-89742aefa4b2/go.mod h1:pMByvHTf9Beacp5x1UXfOR9xyW/9antXMhjMPG0dEzc=
github.com/sirupsen/logrus v1.3.0 h1:hI/7Q+DtNZ2kINb6qt/lS+IyXnHQe9e90POfeewL/ME= github.com/sirupsen/logrus v1.3.0 h1:hI/7Q+DtNZ2kINb6qt/lS+IyXnHQe9e90POfeewL/ME=
github.com/sirupsen/logrus v1.3.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.3.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
github.com/stretchr/testify v0.0.0-20170809224252-890a5c3458b4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v0.0.0-20170809224252-890a5c3458b4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/tidwall/gjson v1.0.2 h1:5BsM7kyEAHAUGEGDkEKO9Mdyiuw6QQ6TSDdarP0Nnmk= github.com/tidwall/gjson v1.0.2 h1:5BsM7kyEAHAUGEGDkEKO9Mdyiuw6QQ6TSDdarP0Nnmk=
github.com/tidwall/gjson v1.0.2/go.mod h1:c/nTNbUr0E0OrXEhq1pwa8iEgc2DOt4ZZqAt1HtCkPA= github.com/tidwall/gjson v1.0.2/go.mod h1:c/nTNbUr0E0OrXEhq1pwa8iEgc2DOt4ZZqAt1HtCkPA=
github.com/tidwall/gjson v1.1.5 h1:QysILxBeUEY3GTLA0fQVgkQG1zme8NxGvhh2SSqWNwI= github.com/tidwall/gjson v1.1.5 h1:QysILxBeUEY3GTLA0fQVgkQG1zme8NxGvhh2SSqWNwI=
@ -128,6 +138,9 @@ golang.org/x/sys v0.0.0-20171012164349-43eea11bc926 h1:PY6OU86NqbyZiOzaPnDw6oOjA
golang.org/x/sys v0.0.0-20171012164349-43eea11bc926/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20171012164349-43eea11bc926/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33 h1:I6FyU15t786LL7oL/hn43zqTuEGr4PN7F4XJ1p4E3Y8= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33 h1:I6FyU15t786LL7oL/hn43zqTuEGr4PN7F4XJ1p4E3Y8=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190712062909-fae7ac547cb7 h1:LepdCS8Gf/MVejFIt8lsiexZATdoGVyp5bcyS+rYoUI=
golang.org/x/sys v0.0.0-20190712062909-fae7ac547cb7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
gopkg.in/Shopify/sarama.v1 v1.11.0 h1:/3kaCyeYaPbr59IBjeqhIcUOB1vXlIVqXAYa5g5C5F0= gopkg.in/Shopify/sarama.v1 v1.11.0 h1:/3kaCyeYaPbr59IBjeqhIcUOB1vXlIVqXAYa5g5C5F0=
gopkg.in/Shopify/sarama.v1 v1.11.0/go.mod h1:AxnvoaevB2nBjNK17cG61A3LleFcWFwVBHBt+cot4Oc= gopkg.in/Shopify/sarama.v1 v1.11.0/go.mod h1:AxnvoaevB2nBjNK17cG61A3LleFcWFwVBHBt+cot4Oc=
gopkg.in/airbrake/gobrake.v2 v2.0.9/go.mod h1:/h5ZAUhDkGaJfjzjKLSjv6zCL6O0LLBxU4K+aSYdM/U= gopkg.in/airbrake/gobrake.v2 v2.0.9/go.mod h1:/h5ZAUhDkGaJfjzjKLSjv6zCL6O0LLBxU4K+aSYdM/U=

View file

@ -19,6 +19,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/publicroomsapi/storage" "github.com/matrix-org/dendrite/publicroomsapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -39,7 +40,7 @@ func GetVisibility(
var v roomVisibility var v roomVisibility
if isPublic { if isPublic {
v.Visibility = "public" v.Visibility = gomatrixserverlib.Public
} else { } else {
v.Visibility = "private" v.Visibility = "private"
} }
@ -61,7 +62,7 @@ func SetVisibility(
return *reqErr return *reqErr
} }
isPublic := v.Visibility == "public" isPublic := v.Visibility == gomatrixserverlib.Public
if err := publicRoomsDatabase.SetRoomVisibility(req.Context(), isPublic, roomID); err != nil { if err := publicRoomsDatabase.SetRoomVisibility(req.Context(), isPublic, roomID); err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
} }

View file

@ -42,8 +42,8 @@ type publicRoomRes struct {
Estimate int64 `json:"total_room_count_estimate,omitempty"` Estimate int64 `json:"total_room_count_estimate,omitempty"`
} }
// GetPublicRooms implements GET /publicRooms // GetPostPublicRooms implements GET and POST /publicRooms
func GetPublicRooms( func GetPostPublicRooms(
req *http.Request, publicRoomDatabase *storage.PublicRoomsServerDatabase, req *http.Request, publicRoomDatabase *storage.PublicRoomsServerDatabase,
) util.JSONResponse { ) util.JSONResponse {
var limit int16 var limit int16

View file

@ -64,7 +64,7 @@ func Setup(apiMux *mux.Router, deviceDB *devices.Database, publicRoomsDB *storag
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/publicRooms", r0mux.Handle("/publicRooms",
common.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse { common.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse {
return directory.GetPublicRooms(req, publicRoomsDB) return directory.GetPostPublicRooms(req, publicRoomsDB)
}), }),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
} }

View file

@ -185,7 +185,7 @@ func (d *PublicRoomsServerDatabase) updateNumJoinedUsers(
return err return err
} }
if membership != "join" { if membership != gomatrixserverlib.Join {
return nil return nil
} }

View file

@ -23,7 +23,7 @@ func IsServerAllowed(
) bool { ) bool {
for _, ev := range authEvents { for _, ev := range authEvents {
membership, err := ev.Membership() membership, err := ev.Membership()
if err != nil || membership != "join" { if err != nil || membership != gomatrixserverlib.Join {
continue continue
} }

View file

@ -23,13 +23,6 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
// Membership values
// TODO: Factor these out somewhere sensible?
const join = "join"
const leave = "leave"
const invite = "invite"
const ban = "ban"
// updateMembership updates the current membership and the invites for each // updateMembership updates the current membership and the invites for each
// user affected by a change in the current state of the room. // user affected by a change in the current state of the room.
// Returns a list of output events to write to the kafka log to inform the // Returns a list of output events to write to the kafka log to inform the
@ -91,8 +84,8 @@ func updateMembership(
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
var err error var err error
// Default the membership to Leave if no event was added or removed. // Default the membership to Leave if no event was added or removed.
oldMembership := leave oldMembership := gomatrixserverlib.Leave
newMembership := leave newMembership := gomatrixserverlib.Leave
if remove != nil { if remove != nil {
oldMembership, err = remove.Membership() oldMembership, err = remove.Membership()
@ -106,7 +99,7 @@ func updateMembership(
return nil, err return nil, err
} }
} }
if oldMembership == newMembership && newMembership != join { if oldMembership == newMembership && newMembership != gomatrixserverlib.Join {
// If the membership is the same then nothing changed and we can return // If the membership is the same then nothing changed and we can return
// immediately, unless it's a Join update (e.g. profile update). // immediately, unless it's a Join update (e.g. profile update).
return updates, nil return updates, nil
@ -118,11 +111,11 @@ func updateMembership(
} }
switch newMembership { switch newMembership {
case invite: case gomatrixserverlib.Invite:
return updateToInviteMembership(mu, add, updates) return updateToInviteMembership(mu, add, updates)
case join: case gomatrixserverlib.Join:
return updateToJoinMembership(mu, add, updates) return updateToJoinMembership(mu, add, updates)
case leave, ban: case gomatrixserverlib.Leave, gomatrixserverlib.Ban:
return updateToLeaveMembership(mu, add, newMembership, updates) return updateToLeaveMembership(mu, add, newMembership, updates)
default: default:
panic(fmt.Errorf( panic(fmt.Errorf(
@ -183,7 +176,7 @@ func updateToJoinMembership(
for _, eventID := range retired { for _, eventID := range retired {
orie := api.OutputRetireInviteEvent{ orie := api.OutputRetireInviteEvent{
EventID: eventID, EventID: eventID,
Membership: join, Membership: gomatrixserverlib.Join,
RetiredByEventID: add.EventID(), RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(), TargetUserID: *add.StateKey(),
} }

View file

@ -359,7 +359,7 @@ func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(
return nil, err return nil, err
} }
if membership == "join" { if membership == gomatrixserverlib.Join {
events = append(events, event) events = append(events, event)
} }
} }

View file

@ -23,7 +23,7 @@ CREATE INDEX IF NOT EXISTS syncapi_invites_target_user_id_idx
-- For deleting old invites -- For deleting old invites
CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx
ON syncapi_invite_events(target_user_id, id); ON syncapi_invite_events (event_id);
` `
const insertInviteEventSQL = "" + const insertInviteEventSQL = "" +

View file

@ -235,6 +235,7 @@ func (d *SyncServerDatasource) addPDUDeltaToResponse(
device authtypes.Device, device authtypes.Device,
fromPos, toPos int64, fromPos, toPos int64,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
wantFullState bool,
res *types.Response, res *types.Response,
) ([]string, error) { ) ([]string, error) {
txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot) txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot)
@ -248,14 +249,18 @@ func (d *SyncServerDatasource) addPDUDeltaToResponse(
// joined rooms, but also which rooms have membership transitions for this user between the 2 PDU stream positions. // joined rooms, but also which rooms have membership transitions for this user between the 2 PDU stream positions.
// This works out what the 'state' key should be for each room as well as which membership block // This works out what the 'state' key should be for each room as well as which membership block
// to put the room into. // to put the room into.
deltas, err := d.getStateDeltas(ctx, &device, txn, fromPos, toPos, device.UserID) var deltas []stateDelta
var joinedRoomIDs []string
if !wantFullState {
deltas, joinedRoomIDs, err = d.getStateDeltas(ctx, &device, txn, fromPos, toPos, device.UserID)
} else {
deltas, joinedRoomIDs, err = d.getStateDeltasForFullStateSync(ctx, &device, txn, fromPos, toPos, device.UserID)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
joinedRoomIDs := make([]string, 0, len(deltas))
for _, delta := range deltas { for _, delta := range deltas {
joinedRoomIDs = append(joinedRoomIDs, delta.roomID)
err = d.addRoomDeltaToResponse(ctx, &device, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res) err = d.addRoomDeltaToResponse(ctx, &device, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res)
if err != nil { if err != nil {
return nil, err return nil, err
@ -332,19 +337,20 @@ func (d *SyncServerDatasource) IncrementalSync(
device authtypes.Device, device authtypes.Device,
fromPos, toPos types.SyncPosition, fromPos, toPos types.SyncPosition,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
wantFullState bool,
) (*types.Response, error) { ) (*types.Response, error) {
nextBatchPos := fromPos.WithUpdates(toPos) nextBatchPos := fromPos.WithUpdates(toPos)
res := types.NewResponse(nextBatchPos) res := types.NewResponse(nextBatchPos)
var joinedRoomIDs []string var joinedRoomIDs []string
var err error var err error
if fromPos.PDUPosition != toPos.PDUPosition { if fromPos.PDUPosition != toPos.PDUPosition || wantFullState {
joinedRoomIDs, err = d.addPDUDeltaToResponse( joinedRoomIDs, err = d.addPDUDeltaToResponse(
ctx, device, fromPos.PDUPosition, toPos.PDUPosition, numRecentEventsPerRoom, res, ctx, device, fromPos.PDUPosition, toPos.PDUPosition, numRecentEventsPerRoom, wantFullState, res,
) )
} else { } else {
joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership( joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(
ctx, nil, device.UserID, "join", ctx, nil, device.UserID, gomatrixserverlib.Join,
) )
} }
if err != nil { if err != nil {
@ -393,7 +399,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
res = types.NewResponse(toPos) res = types.NewResponse(toPos)
// Extract room state and recent events for all rooms the user is joined to. // Extract room state and recent events for all rooms the user is joined to.
joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, "join") joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
if err != nil { if err != nil {
return return
} }
@ -571,7 +577,7 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
res *types.Response, res *types.Response,
) error { ) error {
endPos := toPos endPos := toPos
if delta.membershipPos > 0 && delta.membership == "leave" { if delta.membershipPos > 0 && delta.membership == gomatrixserverlib.Leave {
// make sure we don't leak recent events after the leave event. // make sure we don't leak recent events after the leave event.
// TODO: History visibility makes this somewhat complex to handle correctly. For example: // TODO: History visibility makes this somewhat complex to handle correctly. For example:
// TODO: This doesn't work for join -> leave in a single /sync request (see events prior to join). // TODO: This doesn't work for join -> leave in a single /sync request (see events prior to join).
@ -589,38 +595,42 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
recentEvents := streamEventsToEvents(device, recentStreamEvents) recentEvents := streamEventsToEvents(device, recentStreamEvents)
delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back
var prevPDUPos int64
if len(recentEvents) == 0 {
if len(delta.stateEvents) == 0 {
// Don't bother appending empty room entries // Don't bother appending empty room entries
if len(recentEvents) == 0 && len(delta.stateEvents) == 0 {
return nil return nil
} }
// If full_state=true and since is already up to date, then we'll have
// state events but no recent events.
prevPDUPos = toPos - 1
} else {
prevPDUPos = recentStreamEvents[0].streamPosition - 1
}
if prevPDUPos <= 0 {
prevPDUPos = 1
}
switch delta.membership { switch delta.membership {
case "join": case gomatrixserverlib.Join:
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
if prevPDUPos := recentStreamEvents[0].streamPosition - 1; prevPDUPos > 0 {
// Use the short form of batch token for prev_batch // Use the short form of batch token for prev_batch
jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10)
} else {
// Use the short form of batch token for prev_batch
jr.Timeline.PrevBatch = "1"
}
jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
jr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) jr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Join[delta.roomID] = *jr res.Rooms.Join[delta.roomID] = *jr
case "leave": case gomatrixserverlib.Leave:
fallthrough // transitions to leave are the same as ban fallthrough // transitions to leave are the same as ban
case "ban": case gomatrixserverlib.Ban:
// TODO: recentEvents may contain events that this user is not allowed to see because they are // TODO: recentEvents may contain events that this user is not allowed to see because they are
// no longer in the room. // no longer in the room.
lr := types.NewLeaveResponse() lr := types.NewLeaveResponse()
if prevPDUPos := recentStreamEvents[0].streamPosition - 1; prevPDUPos > 0 {
// Use the short form of batch token for prev_batch // Use the short form of batch token for prev_batch
lr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) lr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10)
} else {
// Use the short form of batch token for prev_batch
lr.Timeline.PrevBatch = "1"
}
lr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) lr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
lr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) lr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
@ -716,10 +726,14 @@ func (d *SyncServerDatasource) fetchMissingStateEvents(
return events, nil return events, nil
} }
// getStateDeltas returns the state deltas between fromPos and toPos,
// exclusive of oldPos, inclusive of newPos, for the rooms in which
// the user has new membership events.
// A list of joined room IDs is also returned in case the caller needs it.
func (d *SyncServerDatasource) getStateDeltas( func (d *SyncServerDatasource) getStateDeltas(
ctx context.Context, device *authtypes.Device, txn *sql.Tx, ctx context.Context, device *authtypes.Device, txn *sql.Tx,
fromPos, toPos int64, userID string, fromPos, toPos int64, userID string,
) ([]stateDelta, error) { ) ([]stateDelta, []string, error) {
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
// - Get membership list changes for this user in this sync response // - Get membership list changes for this user in this sync response
// - For each room which has membership list changes: // - For each room which has membership list changes:
@ -733,11 +747,11 @@ func (d *SyncServerDatasource) getStateDeltas(
// get all the state events ever between these two positions // get all the state events ever between these two positions
stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos) stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
for roomID, stateStreamEvents := range state { for roomID, stateStreamEvents := range state {
@ -748,16 +762,12 @@ func (d *SyncServerDatasource) getStateDeltas(
// the 'state' part of the response though, so is transparent modulo bandwidth concerns as it is not added to // the 'state' part of the response though, so is transparent modulo bandwidth concerns as it is not added to
// the timeline. // the timeline.
if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" { if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" {
if membership == "join" { if membership == gomatrixserverlib.Join {
// send full room state down instead of a delta // send full room state down instead of a delta
var allState []gomatrixserverlib.Event var s []streamEvent
allState, err = d.roomstate.selectCurrentState(ctx, txn, roomID) s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID)
if err != nil { if err != nil {
return nil, err return nil, nil, err
}
s := make([]streamEvent, len(allState))
for i := 0; i < len(s); i++ {
s[i] = streamEvent{Event: allState[i], streamPosition: 0}
} }
state[roomID] = s state[roomID] = s
continue // we'll add this room in when we do joined rooms continue // we'll add this room in when we do joined rooms
@ -775,19 +785,92 @@ func (d *SyncServerDatasource) getStateDeltas(
} }
// Add in currently joined rooms // Add in currently joined rooms
joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, "join") joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
for _, joinedRoomID := range joinedRoomIDs { for _, joinedRoomID := range joinedRoomIDs {
deltas = append(deltas, stateDelta{ deltas = append(deltas, stateDelta{
membership: "join", membership: gomatrixserverlib.Join,
stateEvents: streamEventsToEvents(device, state[joinedRoomID]), stateEvents: streamEventsToEvents(device, state[joinedRoomID]),
roomID: joinedRoomID, roomID: joinedRoomID,
}) })
} }
return deltas, nil return deltas, joinedRoomIDs, nil
}
// getStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync
// requests with full_state=true.
// Fetches full state for all joined rooms and uses selectStateInRange to get
// updates for other rooms.
func (d *SyncServerDatasource) getStateDeltasForFullStateSync(
ctx context.Context, device *authtypes.Device, txn *sql.Tx,
fromPos, toPos int64, userID string,
) ([]stateDelta, []string, error) {
joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
if err != nil {
return nil, nil, err
}
// Use a reasonable initial capacity
deltas := make([]stateDelta, 0, len(joinedRoomIDs))
// Add full states for all joined rooms
for _, joinedRoomID := range joinedRoomIDs {
s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID)
if stateErr != nil {
return nil, nil, stateErr
}
deltas = append(deltas, stateDelta{
membership: gomatrixserverlib.Join,
stateEvents: streamEventsToEvents(device, s),
roomID: joinedRoomID,
})
}
// Get all the state events ever between these two positions
stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos)
if err != nil {
return nil, nil, err
}
state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
if err != nil {
return nil, nil, err
}
for roomID, stateStreamEvents := range state {
for _, ev := range stateStreamEvents {
if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" {
if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above.
deltas = append(deltas, stateDelta{
membership: membership,
membershipPos: ev.streamPosition,
stateEvents: streamEventsToEvents(device, stateStreamEvents),
roomID: roomID,
})
}
break
}
}
}
return deltas, joinedRoomIDs, nil
}
func (d *SyncServerDatasource) currentStateStreamEventsForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) ([]streamEvent, error) {
allState, err := d.roomstate.selectCurrentState(ctx, txn, roomID)
if err != nil {
return nil, err
}
s := make([]streamEvent, len(allState))
for i := 0; i < len(s); i++ {
s[i] = streamEvent{Event: allState[i], streamPosition: 0}
}
return s, nil
} }
// streamEventsToEvents converts streamEvent to Event. If device is non-nil and // streamEventsToEvents converts streamEvent to Event. If device is non-nil and

View file

@ -93,16 +93,16 @@ func (n *Notifier) OnNewEvent(
} else { } else {
// Keep the joined user map up-to-date // Keep the joined user map up-to-date
switch membership { switch membership {
case "invite": case gomatrixserverlib.Invite:
usersToNotify = append(usersToNotify, targetUserID) usersToNotify = append(usersToNotify, targetUserID)
case "join": case gomatrixserverlib.Join:
// Manually append the new user's ID so they get notified // Manually append the new user's ID so they get notified
// along all members in the room // along all members in the room
usersToNotify = append(usersToNotify, targetUserID) usersToNotify = append(usersToNotify, targetUserID)
n.addJoinedUser(ev.RoomID(), targetUserID) n.addJoinedUser(ev.RoomID(), targetUserID)
case "leave": case gomatrixserverlib.Leave:
fallthrough fallthrough
case "ban": case gomatrixserverlib.Ban:
n.removeJoinedUser(ev.RoomID(), targetUserID) n.removeJoinedUser(ev.RoomID(), targetUserID)
} }
} }
@ -185,6 +185,7 @@ func (n *Notifier) wakeupUsers(userIDs []string, newPos types.SyncPosition) {
// fetchUserStream retrieves a stream unique to the given user. If makeIfNotExists is true, // fetchUserStream retrieves a stream unique to the given user. If makeIfNotExists is true,
// a stream will be made for this user if one doesn't exist and it will be returned. This // a stream will be made for this user if one doesn't exist and it will be returned. This
// function does not wait for data to be available on the stream. // function does not wait for data to be available on the stream.
// NB: Callers should have locked the mutex before calling this function.
func (n *Notifier) fetchUserStream(userID string, makeIfNotExists bool) *UserStream { func (n *Notifier) fetchUserStream(userID string, makeIfNotExists bool) *UserStream {
stream, ok := n.userStreams[userID] stream, ok := n.userStreams[userID]
if !ok && makeIfNotExists { if !ok && makeIfNotExists {

View file

@ -143,7 +143,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) {
wg.Done() wg.Done()
}() }()
stream := n.fetchUserStream(bob, true) stream := lockedFetchUserStream(n, bob)
waitForBlocking(stream, 1) waitForBlocking(stream, 1)
n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter) n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter)
@ -171,7 +171,7 @@ func TestNewInviteEventForUser(t *testing.T) {
wg.Done() wg.Done()
}() }()
stream := n.fetchUserStream(bob, true) stream := lockedFetchUserStream(n, bob)
waitForBlocking(stream, 1) waitForBlocking(stream, 1)
n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionAfter) n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionAfter)
@ -199,7 +199,7 @@ func TestEDUWakeup(t *testing.T) {
wg.Done() wg.Done()
}() }()
stream := n.fetchUserStream(bob, true) stream := lockedFetchUserStream(n, bob)
waitForBlocking(stream, 1) waitForBlocking(stream, 1)
n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionNewEDU) n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionNewEDU)
@ -230,7 +230,7 @@ func TestMultipleRequestWakeup(t *testing.T) {
go poll() go poll()
go poll() go poll()
stream := n.fetchUserStream(bob, true) stream := lockedFetchUserStream(n, bob)
waitForBlocking(stream, 3) waitForBlocking(stream, 3)
n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter) n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter)
@ -266,14 +266,14 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
} }
leaveWG.Done() leaveWG.Done()
}() }()
bobStream := n.fetchUserStream(bob, true) bobStream := lockedFetchUserStream(n, bob)
waitForBlocking(bobStream, 1) waitForBlocking(bobStream, 1)
n.OnNewEvent(&bobLeaveEvent, "", nil, syncPositionAfter) n.OnNewEvent(&bobLeaveEvent, "", nil, syncPositionAfter)
leaveWG.Wait() leaveWG.Wait()
// send an event into the room. Make sure alice gets it. Bob should not. // send an event into the room. Make sure alice gets it. Bob should not.
var aliceWG sync.WaitGroup var aliceWG sync.WaitGroup
aliceStream := n.fetchUserStream(alice, true) aliceStream := lockedFetchUserStream(n, alice)
aliceWG.Add(1) aliceWG.Add(1)
go func() { go func() {
pos, err := waitForEvents(n, newTestSyncRequest(alice, syncPositionAfter)) pos, err := waitForEvents(n, newTestSyncRequest(alice, syncPositionAfter))
@ -328,6 +328,15 @@ func waitForBlocking(s *UserStream, numBlocking uint) {
} }
} }
// lockedFetchUserStream invokes Notifier.fetchUserStream, respecting Notifier.streamLock.
// A new stream is made if it doesn't exist already.
func lockedFetchUserStream(n *Notifier, userID string) *UserStream {
n.streamLock.Lock()
defer n.streamLock.Unlock()
return n.fetchUserStream(userID, true)
}
func newTestSyncRequest(userID string, since types.SyncPosition) syncRequest { func newTestSyncRequest(userID string, since types.SyncPosition) syncRequest {
return syncRequest{ return syncRequest{
device: authtypes.Device{UserID: userID}, device: authtypes.Device{UserID: userID},

View file

@ -65,8 +65,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
currPos := rp.notifier.CurrentPosition() currPos := rp.notifier.CurrentPosition()
// If this is an initial sync or timeout=0 we return immediately if shouldReturnImmediately(syncReq) {
if syncReq.since == nil || syncReq.timeout == 0 {
syncData, err = rp.currentSyncForUser(*syncReq, currPos) syncData, err = rp.currentSyncForUser(*syncReq, currPos)
if err != nil { if err != nil {
return httputil.LogThenError(req, err) return httputil.LogThenError(req, err)
@ -135,7 +134,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.SyncP
if req.since == nil { if req.since == nil {
res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit) res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit)
} else { } else {
res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, latestPos, req.limit) res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, latestPos, req.limit, req.wantFullState)
} }
if err != nil { if err != nil {
@ -216,3 +215,10 @@ func (rp *RequestPool) appendAccountData(
return data, nil return data, nil
} }
// shouldReturnImmediately returns whether the /sync request is an initial sync,
// or timeout=0, or full_state=true, in any of the cases the request should
// return immediately.
func shouldReturnImmediately(syncReq *syncRequest) bool {
return syncReq.since == nil || syncReq.timeout == 0 || syncReq.wantFullState
}

View file

@ -149,3 +149,20 @@ Typing events appear in incremental sync
Typing events appear in gapped sync Typing events appear in gapped sync
Inbound federation of state requires event_id as a mandatory paramater Inbound federation of state requires event_id as a mandatory paramater
Inbound federation of state_ids requires event_id as a mandatory paramater Inbound federation of state_ids requires event_id as a mandatory paramater
POST /register returns the same device_id as that in the request
POST /login returns the same device_id as that in the request
POST /createRoom with creation content
User can create and send/receive messages in a room with version 1
POST /createRoom ignores attempts to set the room version via creation_content
Inbound federation rejects remote attempts to join local users to rooms
Inbound federation rejects remote attempts to kick local users to rooms
An event which redacts itself should be ignored
A pair of events which redact each other should be ignored
Full state sync includes joined rooms
Can add tag
Can remove tag
Can list tags for a room
Tags appear in an initial v2 /sync
Newly updated tags appear in an incremental v2 /sync
Deleted tags appear in an incremental v2 /sync
/event/ on non world readable room does not work