Merge remote-tracking branch 'origin/master' into markjh/invitesV

This commit is contained in:
Mark Haines 2017-09-20 15:27:08 +01:00
commit a020c41eea
35 changed files with 120 additions and 150 deletions

View file

@ -13,12 +13,21 @@ rm -f sytest/server-*/*.log sytest/results.tap
./jenkins/prepare-dendrite.sh
if [ ! -d "sytest" ]; then
git clone https://github.com/matrix-org/sytest.git --depth 1 --branch dendrite
else
git -C sytest fetch --depth 1 origin dendrite
git -C sytest reset --hard FETCH_HEAD
git clone https://github.com/matrix-org/sytest.git --depth 1 --branch master
fi
# Jenkins may have supplied us with the name of the branch in the
# environment. Otherwise we will have to guess based on the current
# commit.
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
git -C sytest fetch --depth 1 origin "${GIT_BRANCH}" || {
echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop"
git -C sytest fetch --depth 1 origin develop
}
git -C sytest reset --hard FETCH_HEAD
./sytest/jenkins/prep_sytest_for_postgres.sh
./sytest/jenkins/install_and_run.sh \

View file

@ -10,6 +10,8 @@
"ineffassign",
"gas",
"misspell",
"errcheck"
"errcheck",
"vet",
"goconst"
]
}

View file

@ -15,6 +15,9 @@
"gas",
"misspell",
"unparam",
"errcheck"
"errcheck",
"vet",
"megacheck",
"goconst"
]
}

View file

@ -51,14 +51,10 @@ const selectMembershipsByLocalpartSQL = "" +
const deleteMembershipsByEventIDsSQL = "" +
"DELETE FROM account_memberships WHERE event_id = ANY($1)"
const updateMembershipByEventIDSQL = "" +
"UPDATE account_memberships SET event_id = $2 WHERE event_id = $1"
type membershipStatements struct {
deleteMembershipsByEventIDsStmt *sql.Stmt
insertMembershipStmt *sql.Stmt
selectMembershipsByLocalpartStmt *sql.Stmt
updateMembershipByEventIDStmt *sql.Stmt
}
func (s *membershipStatements) prepare(db *sql.DB) (err error) {
@ -75,9 +71,6 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil {
return
}
if s.updateMembershipByEventIDStmt, err = db.Prepare(updateMembershipByEventIDSQL); err != nil {
return
}
return
}
@ -120,12 +113,3 @@ func (s *membershipStatements) selectMembershipsByLocalpart(
return
}
func (s *membershipStatements) updateMembershipByEventID(
ctx context.Context, oldEventID string, newEventID string,
) (err error) {
_, err = s.updateMembershipByEventIDStmt.ExecContext(
ctx, oldEventID, newEventID,
)
return
}

View file

@ -67,10 +67,7 @@ func (d *Database) CreateDevice(
}
dev, err = d.devices.insertDevice(ctx, txn, deviceID, localpart, accessToken)
if err != nil {
return err
}
return nil
return err
})
return
}

View file

@ -45,9 +45,6 @@ func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string
m.Key = sarama.StringEncoder(userID)
m.Value = sarama.ByteEncoder(value)
if _, _, err := p.Producer.SendMessage(&m); err != nil {
return err
}
return nil
_, _, err = p.Producer.SendMessage(&m)
return err
}

View file

@ -57,9 +57,6 @@ func (p *UserUpdateProducer) SendUpdate(
}
m.Value = sarama.ByteEncoder(value)
if _, _, err := p.Producer.SendMessage(&m); err != nil {
return err
}
return nil
_, _, err = p.Producer.SendMessage(&m)
return err
}

View file

@ -163,12 +163,12 @@ func createRoom(req *http.Request, device *authtypes.Device,
{"m.room.member", userID, membershipContent},
{"m.room.power_levels", "", common.InitialPowerLevelsContent(userID)},
// TODO: m.room.canonical_alias
{"m.room.join_rules", "", common.JoinRulesContent{"public"}}, // FIXME: Allow this to be changed
{"m.room.history_visibility", "", common.HistoryVisibilityContent{"joined"}}, // FIXME: Allow this to be changed
{"m.room.guest_access", "", common.GuestAccessContent{"can_join"}}, // FIXME: Allow this to be changed
{"m.room.join_rules", "", common.JoinRulesContent{JoinRule: "public"}}, // FIXME: Allow this to be changed
{"m.room.history_visibility", "", common.HistoryVisibilityContent{HistoryVisibility: "joined"}}, // FIXME: Allow this to be changed
{"m.room.guest_access", "", common.GuestAccessContent{GuestAccess: "can_join"}}, // FIXME: Allow this to be changed
// TODO: Other initial state items
{"m.room.name", "", common.NameContent{r.Name}}, // FIXME: Only send the name event if a name is supplied, to avoid sending a false room name removal event
{"m.room.topic", "", common.TopicContent{r.Topic}},
{"m.room.name", "", common.NameContent{Name: r.Name}}, // FIXME: Only send the name event if a name is supplied, to avoid sending a false room name removal event
{"m.room.topic", "", common.TopicContent{Topic: r.Topic}},
// TODO: invite events
// TODO: 3pid invite events
// TODO: m.room.aliases

View file

@ -67,7 +67,7 @@ func main() {
keyRing := gomatrixserverlib.KeyRing{
KeyFetchers: []gomatrixserverlib.KeyFetcher{
// TODO: Use perspective key fetchers for production.
&gomatrixserverlib.DirectKeyFetcher{federation.Client},
&gomatrixserverlib.DirectKeyFetcher{Client: federation.Client},
},
KeyDatabase: keyDB,
}

View file

@ -80,7 +80,9 @@ func main() {
aliasAPI.SetupHTTP(http.DefaultServeMux)
http.DefaultServeMux.Handle("/metrics", prometheus.Handler())
// This is deprecated, but prometheus are still arguing on what to replace
// it with. Alternatively we could set it up manually.
http.DefaultServeMux.Handle("/metrics", prometheus.Handler()) // nolint: staticcheck, megacheck
log.Info("Started room server on ", cfg.Listen.RoomServer)

View file

@ -190,7 +190,7 @@ func getMediaURI(host, endpoint, query string, components []string) string {
func testUpload(host, filePath string) {
fmt.Printf("==TESTING== upload %v to %v\n", filePath, host)
file, err := os.Open(filePath)
defer file.Close() // nolint: errcheck
defer file.Close() // nolint: errcheck, staticcheck, megacheck
if err != nil {
panic(err)
}

View file

@ -14,7 +14,7 @@
package main
// nolint: varcheck, deadcode
// nolint: varcheck, deadcode, unused, megacheck
const (
i0StateRoomCreate = iota
i1StateAliceJoin

View file

@ -413,7 +413,7 @@ func fingerprintPEM(data []byte) *gomatrixserverlib.TLSFingerprint {
}
if certDERBlock.Type == "CERTIFICATE" {
digest := sha256.Sum256(certDERBlock.Bytes)
return &gomatrixserverlib.TLSFingerprint{digest[:]}
return &gomatrixserverlib.TLSFingerprint{SHA256: digest[:]}
}
}
}

View file

@ -52,6 +52,7 @@ func MakeFedAPI(
// SetupHTTPAPI registers an HTTP API mux under /api and sets up a metrics
// listener.
func SetupHTTPAPI(servMux *http.ServeMux, apiMux *mux.Router) {
servMux.Handle("/metrics", prometheus.Handler())
// This is deprecated.
servMux.Handle("/metrics", prometheus.Handler()) // nolint: megacheck, staticcheck
servMux.Handle("/api/", http.StripPrefix("/api", apiMux))
}

View file

@ -97,7 +97,8 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
return nil, err
}
r := gomatrixserverlib.PublicKeyRequest{
gomatrixserverlib.ServerName(serverName), gomatrixserverlib.KeyID(keyID),
ServerName: gomatrixserverlib.ServerName(serverName),
KeyID: gomatrixserverlib.KeyID(keyID),
}
results[r] = serverKeys
}
@ -115,10 +116,7 @@ func (s *serverKeyStatements) upsertServerKeys(
string(request.ServerName), string(request.KeyID), nameAndKeyID(request),
int64(keys.ValidUntilTS), keyJSON,
)
if err != nil {
return err
}
return nil
return err
}
func nameAndKeyID(request gomatrixserverlib.PublicKeyRequest) string {

View file

@ -47,6 +47,7 @@ type PartitionOffsetStatements struct {
// Prepare converts the raw SQL statements into prepared statements.
// Takes a prefix to prepend to the table name used to store the partition offsets.
// This allows multiple components to share the same database schema.
// nolint: safesql
func (s *PartitionOffsetStatements) Prepare(db *sql.DB, prefix string) (err error) {
_, err = db.Exec(strings.Replace(partitionOffsetsSchema, "${prefix}", prefix, -1))
if err != nil {

View file

@ -135,17 +135,14 @@ func NewMatrixKey(matrixKeyPath string) (err error) {
err = keyOut.Close()
})()
if err = pem.Encode(keyOut, &pem.Block{
err = pem.Encode(keyOut, &pem.Block{
Type: "MATRIX PRIVATE KEY",
Headers: map[string]string{
"Key-ID": "ed25519:" + base64.RawStdEncoding.EncodeToString(data[:3]),
},
Bytes: data[3:],
}); err != nil {
return err
}
return nil
})
return err
}
const certificateDuration = time.Hour * 24 * 365 * 10
@ -191,12 +188,9 @@ func NewTLSKey(tlsKeyPath, tlsCertPath string) error {
return err
}
defer keyOut.Close() // nolint: errcheck
if err = pem.Encode(keyOut, &pem.Block{
err = pem.Encode(keyOut, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(priv),
}); err != nil {
return err
}
return nil
})
return err
}

View file

@ -44,7 +44,7 @@ func localKeys(cfg config.Dendrite, validUntil time.Time) (*gomatrixserverlib.Se
keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{
cfg.Matrix.KeyID: {
gomatrixserverlib.Base64String(publicKey),
Key: gomatrixserverlib.Base64String(publicKey),
},
}

View file

@ -347,9 +347,5 @@ func fillDisplayName(
// Use the m.room.third_party_invite event to fill the "displayname" and
// update the m.room.member event's content with it
content.ThirdPartyInvite.DisplayName = thirdPartyInviteContent.DisplayName
if err := builder.SetContent(content); err != nil {
return err
}
return nil
return builder.SetContent(content)
}

View file

@ -55,7 +55,7 @@ func GetPathFromBase64Hash(base64Hash types.Base64Hash, absBasePath config.Path)
// check if the absolute absBasePath is a prefix of the absolute filePath
// if so, no directory escape has occurred and the filePath is valid
// Note: absBasePath is already absolute
if strings.HasPrefix(filePath, string(absBasePath)) == false {
if !strings.HasPrefix(filePath, string(absBasePath)) {
return "", fmt.Errorf("Invalid filePath (not within absBasePath %v): %v", absBasePath, filePath)
}

View file

@ -63,28 +63,24 @@ func SelectThumbnail(desired types.ThumbnailSize, thumbnails []*types.ThumbnailM
bestFit := newThumbnailFitness()
for _, thumbnail := range thumbnails {
if desired.ResizeMethod == "scale" && thumbnail.ThumbnailSize.ResizeMethod != "scale" {
if desired.ResizeMethod == types.Scale && thumbnail.ThumbnailSize.ResizeMethod != types.Scale {
continue
}
fitness := calcThumbnailFitness(thumbnail.ThumbnailSize, thumbnail.MediaMetadata, desired)
if isBetter := fitness.betterThan(bestFit, desired.ResizeMethod == "crop"); isBetter {
if isBetter := fitness.betterThan(bestFit, desired.ResizeMethod == types.Crop); isBetter {
bestFit = fitness
chosenThumbnail = thumbnail
}
}
for _, thumbnailSize := range thumbnailSizes {
if desired.ResizeMethod == "scale" && thumbnailSize.ResizeMethod != "scale" {
if desired.ResizeMethod == types.Scale && thumbnailSize.ResizeMethod != types.Scale {
continue
}
fitness := calcThumbnailFitness(types.ThumbnailSize(thumbnailSize), nil, desired)
if isBetter := fitness.betterThan(bestFit, desired.ResizeMethod == "crop"); isBetter {
if isBetter := fitness.betterThan(bestFit, desired.ResizeMethod == types.Crop); isBetter {
bestFit = fitness
chosenThumbnailSize = &types.ThumbnailSize{
Width: thumbnailSize.Width,
Height: thumbnailSize.Height,
ResizeMethod: thumbnailSize.ResizeMethod,
}
chosenThumbnailSize = (*types.ThumbnailSize)(&thumbnailSize)
}
}

View file

@ -149,14 +149,14 @@ func createThumbnail(src types.Path, img image.Image, config types.ThumbnailSize
}
start := time.Now()
width, height, err := adjustSize(dst, img, config.Width, config.Height, config.ResizeMethod == "crop", logger)
width, height, err := adjustSize(dst, img, config.Width, config.Height, config.ResizeMethod == types.Crop, logger)
if err != nil {
return false, err
}
logger.WithFields(log.Fields{
"ActualWidth": width,
"ActualHeight": height,
"processTime": time.Now().Sub(start),
"processTime": time.Since(start),
}).Info("Generated thumbnail")
stat, err := os.Stat(string(dst))

View file

@ -102,3 +102,9 @@ type ActiveThumbnailGeneration struct {
// The string key is a thumbnail file path
PathToResult map[string]*ThumbnailGenerationResult
}
// Crop indicates we should crop the thumbnail on resize
const Crop = "crop"
// Scale indicates we should scale the thumbnail on resize
const Scale = "scale"

View file

@ -155,7 +155,7 @@ func (r *downloadRequest) jsonErrorResponse(w http.ResponseWriter, res util.JSON
// Validate validates the downloadRequest fields
func (r *downloadRequest) Validate() *util.JSONResponse {
if mediaIDRegex.MatchString(string(r.MediaMetadata.MediaID)) == false {
if !mediaIDRegex.MatchString(string(r.MediaMetadata.MediaID)) {
return &util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound(fmt.Sprintf("mediaId must be a non-empty string using only characters in %v", mediaIDCharacters)),
@ -179,9 +179,9 @@ func (r *downloadRequest) Validate() *util.JSONResponse {
}
// Default method to scale if not set
if r.ThumbnailSize.ResizeMethod == "" {
r.ThumbnailSize.ResizeMethod = "scale"
r.ThumbnailSize.ResizeMethod = types.Scale
}
if r.ThumbnailSize.ResizeMethod != "crop" && r.ThumbnailSize.ResizeMethod != "scale" {
if r.ThumbnailSize.ResizeMethod != types.Crop && r.ThumbnailSize.ResizeMethod != types.Scale {
return &util.JSONResponse{
Code: 400,
JSON: jsonerror.Unknown("method must be one of crop or scale"),
@ -236,7 +236,7 @@ func (r *downloadRequest) respondFromLocalFile(
return nil, errors.Wrap(err, "failed to get file path from metadata")
}
file, err := os.Open(filePath)
defer file.Close() // nolint: errcheck
defer file.Close() // nolint: errcheck, staticcheck, megacheck
if err != nil {
return nil, errors.Wrap(err, "failed to open file")
}
@ -337,7 +337,7 @@ func (r *downloadRequest) getThumbnailFile(
thumbnail, thumbnailSize = thumbnailer.SelectThumbnail(r.ThumbnailSize, thumbnails, thumbnailSizes)
// If dynamicThumbnails is true and we are not over-loaded then we would have generated what was requested above.
// So we don't try to generate a pre-generated thumbnail here.
if thumbnailSize != nil && dynamicThumbnails == false {
if thumbnailSize != nil && !dynamicThumbnails {
r.Logger.WithFields(log.Fields{
"Width": thumbnailSize.Width,
"Height": thumbnailSize.Height,
@ -525,7 +525,7 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
// If the file is a duplicate (has the same hash as an existing file) then
// there is valid metadata in the database for that file. As such we only
// remove the file if it is not a duplicate.
if duplicate == false {
if !duplicate {
finalDir := filepath.Dir(string(finalPath))
fileutils.RemoveDir(types.Path(finalDir), r.Logger)
}

View file

@ -226,7 +226,7 @@ func (r *uploadRequest) storeFileAndMetadata(tmpDir types.Path, absBasePath conf
// If the file is a duplicate (has the same hash as an existing file) then
// there is valid metadata in the database for that file. As such we only
// remove the file if it is not a duplicate.
if duplicate == false {
if !duplicate {
fileutils.RemoveDir(types.Path(path.Dir(string(finalPath))), r.Logger)
}
return &util.JSONResponse{

View file

@ -102,10 +102,7 @@ func fillPublicRoomsReq(httpReq *http.Request, request *publicRoomReq) *util.JSO
request.Since = httpReq.FormValue("since")
return nil
} else if httpReq.Method == "POST" {
if reqErr := httputil.UnmarshalJSONRequest(httpReq, request); reqErr != nil {
return reqErr
}
return nil
return httputil.UnmarshalJSONRequest(httpReq, request)
}
return &util.JSONResponse{

View file

@ -134,6 +134,7 @@ type publicRoomsStatements struct {
updateRoomAttributeStmts map[string]*sql.Stmt
}
// nolint: safesql
func (s *publicRoomsStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(publicRoomsSchema)
if err != nil {

View file

@ -217,11 +217,7 @@ func (r *RoomserverAliasAPI) sendUpdatedAliasesEvent(
var inputRes api.InputRoomEventsResponse
// Send the request
if err := r.InputAPI.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil {
return err
}
return nil
return r.InputAPI.InputRoomEvents(ctx, &inputReq, &inputRes)
}
// SetupHTTP adds the RoomserverAliasAPI handlers to the http.ServeMux.

View file

@ -15,24 +15,29 @@
package input
import (
"github.com/matrix-org/dendrite/roomserver/types"
"testing"
"github.com/matrix-org/dendrite/roomserver/types"
)
func benchmarkStateEntryMapLookup(entries, lookups int64, b *testing.B) {
var list []types.StateEntry
for i := int64(0); i < entries; i++ {
list = append(list, types.StateEntry{types.StateKeyTuple{
types.EventTypeNID(i),
types.EventStateKeyNID(i),
}, types.EventNID(i)})
list = append(list, types.StateEntry{
StateKeyTuple: types.StateKeyTuple{
EventTypeNID: types.EventTypeNID(i),
EventStateKeyNID: types.EventStateKeyNID(i),
},
EventNID: types.EventNID(i),
})
}
for i := 0; i < b.N; i++ {
entryMap := stateEntryMap(list)
for j := int64(0); j < lookups; j++ {
entryMap.lookup(types.StateKeyTuple{
types.EventTypeNID(j), types.EventStateKeyNID(j),
EventTypeNID: types.EventTypeNID(j),
EventStateKeyNID: types.EventStateKeyNID(j),
})
}
}
@ -56,9 +61,9 @@ func BenchmarkStateEntryMap1000Lookup10000(b *testing.B) {
func TestStateEntryMap(t *testing.T) {
entryMap := stateEntryMap([]types.StateEntry{
{types.StateKeyTuple{1, 1}, 1},
{types.StateKeyTuple{1, 3}, 2},
{types.StateKeyTuple{2, 1}, 3},
{StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 1}, EventNID: 1},
{StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 3}, EventNID: 2},
{StateKeyTuple: types.StateKeyTuple{EventTypeNID: 2, EventStateKeyNID: 1}, EventNID: 3},
})
testCases := []struct {
@ -78,7 +83,7 @@ func TestStateEntryMap(t *testing.T) {
}
for _, testCase := range testCases {
keyTuple := types.StateKeyTuple{testCase.inputTypeNID, testCase.inputStateKey}
keyTuple := types.StateKeyTuple{EventTypeNID: testCase.inputTypeNID, EventStateKeyNID: testCase.inputStateKey}
gotEventNID, gotOK := entryMap.lookup(keyTuple)
if testCase.wantOK != gotOK {
t.Fatalf("stateEntryMap lookup(%v): want ok to be %v, got %v", keyTuple, testCase.wantOK, gotOK)

View file

@ -102,8 +102,7 @@ type latestEventsUpdater struct {
}
func (u *latestEventsUpdater) doUpdateLatestEvents() error {
var prevEvents []gomatrixserverlib.EventReference
prevEvents = u.event.PrevEvents()
prevEvents := u.event.PrevEvents()
oldLatest := u.updater.LatestEvents()
u.lastEventIDSent = u.updater.LastEventIDSent()
u.oldStateNID = u.updater.CurrentStateSnapshotNID()
@ -194,10 +193,7 @@ func (u *latestEventsUpdater) latestState() error {
u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = state.DifferenceBetweeenStateSnapshots(
u.ctx, u.db, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID,
)
if err != nil {
return err
}
return nil
return err
}
func calculateLatest(
@ -211,7 +207,7 @@ func calculateLatest(
for _, l := range oldLatest {
keep := true
for _, prevEvent := range prevEvents {
if l.EventID == prevEvent.EventID && bytes.Compare(l.EventSHA256, prevEvent.EventSHA256) == 0 {
if l.EventID == prevEvent.EventID && bytes.Equal(l.EventSHA256, prevEvent.EventSHA256) {
// This event can be removed from the latest events cause we've found an event that references it.
// (If an event is referenced by another event then it can't be one of the latest events in the room
// because we have an event that comes after it)

View file

@ -23,6 +23,13 @@ import (
"github.com/matrix-org/gomatrixserverlib"
)
// Membership values
// TODO: Factor these out somewhere sensible?
const join = "join"
const leave = "leave"
const invite = "invite"
const ban = "ban"
// updateMembership updates the current membership and the invites for each
// user affected by a change in the current state of the room.
// Returns a list of output events to write to the kafka log to inform the
@ -83,9 +90,9 @@ func updateMembership(
updates []api.OutputEvent,
) ([]api.OutputEvent, error) {
var err error
// Default the membership to "leave" if no event was added or removed.
old := "leave"
new := "leave"
// Default the membership to Leave if no event was added or removed.
old := leave
new := leave
if remove != nil {
old, err = remove.Membership()
@ -99,9 +106,9 @@ func updateMembership(
return nil, err
}
}
if old == new && new != "join" {
if old == new && new != join {
// If the membership is the same then nothing changed and we can return
// immediately, unless it's a "join" update (e.g. profile update).
// immediately, unless it's a Join update (e.g. profile update).
return updates, nil
}
@ -111,11 +118,11 @@ func updateMembership(
}
switch new {
case "invite":
case invite:
return updateToInviteMembership(mu, add, updates)
case "join":
case join:
return updateToJoinMembership(mu, add, updates)
case "leave", "ban":
case leave, ban:
return updateToLeaveMembership(mu, add, new, updates)
default:
panic(fmt.Errorf(
@ -176,7 +183,7 @@ func updateToJoinMembership(
for _, eventID := range retired {
orie := api.OutputRetireInviteEvent{
EventID: eventID,
Membership: "join",
Membership: join,
RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(),
}

View file

@ -15,8 +15,9 @@
package state
import (
"github.com/matrix-org/dendrite/roomserver/types"
"testing"
"github.com/matrix-org/dendrite/roomserver/types"
)
func TestFindDuplicateStateKeys(t *testing.T) {
@ -25,18 +26,18 @@ func TestFindDuplicateStateKeys(t *testing.T) {
Want []types.StateEntry
}{{
Input: []types.StateEntry{
{types.StateKeyTuple{1, 1}, 1},
{types.StateKeyTuple{1, 1}, 2},
{types.StateKeyTuple{2, 2}, 3},
{StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 1}, EventNID: 1},
{StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 1}, EventNID: 2},
{StateKeyTuple: types.StateKeyTuple{EventTypeNID: 2, EventStateKeyNID: 2}, EventNID: 3},
},
Want: []types.StateEntry{
{types.StateKeyTuple{1, 1}, 1},
{types.StateKeyTuple{1, 1}, 2},
{StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 1}, EventNID: 1},
{StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 1}, EventNID: 2},
},
}, {
Input: []types.StateEntry{
{types.StateKeyTuple{1, 1}, 1},
{types.StateKeyTuple{1, 2}, 2},
{StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 1}, EventNID: 1},
{StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 2}, EventNID: 2},
},
Want: nil,
}}

View file

@ -60,10 +60,6 @@ const bulkSelectEventStateKeyNIDSQL = "" +
"SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys" +
" WHERE event_state_key = ANY($1)"
const selectEventStateKeySQL = "" +
"SELECT event_state_key FROM roomserver_event_state_keys" +
" WHERE event_state_key_nid = $1"
// Bulk lookup from numeric ID to string state key for that state key.
// Takes an array of strings as the query parameter.
const bulkSelectEventStateKeySQL = "" +
@ -73,7 +69,6 @@ const bulkSelectEventStateKeySQL = "" +
type eventStateKeyStatements struct {
insertEventStateKeyNIDStmt *sql.Stmt
selectEventStateKeyNIDStmt *sql.Stmt
selectEventStateKeyStmt *sql.Stmt
bulkSelectEventStateKeyNIDStmt *sql.Stmt
bulkSelectEventStateKeyStmt *sql.Stmt
}
@ -86,7 +81,6 @@ func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) {
return statementList{
{&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL},
{&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL},
{&s.selectEventStateKeyStmt, selectEventStateKeySQL},
{&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL},
{&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL},
}.prepare(db)
@ -133,15 +127,6 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
return result, nil
}
func (s *eventStateKeyStatements) selectEventStateKey(
ctx context.Context, txn *sql.Tx, eventStateKeyNID types.EventStateKeyNID,
) (string, error) {
var eventStateKey string
stmt := common.TxStmt(txn, s.selectEventStateKeyStmt)
err := stmt.QueryRowContext(ctx, eventStateKeyNID).Scan(&eventStateKey)
return eventStateKey, err
}
func (s *eventStateKeyStatements) bulkSelectEventStateKey(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {

View file

@ -25,6 +25,7 @@ type statementList []struct {
}
// prepare the SQL for each statement in the list and assign the result to the prepared statement.
// nolint: safesql
func (s statementList) prepare(db *sql.DB) (err error) {
for _, statement := range s {
if *statement.statement, err = db.Prepare(statement.sql); err != nil {

View file

@ -497,9 +497,7 @@ func (d *SyncServerDatabase) fetchMissingStateEvents(
if len(stateEvents) != len(missing) {
return nil, fmt.Errorf("failed to map all event IDs to events: (got %d, wanted %d)", len(stateEvents), len(missing))
}
for _, e := range stateEvents {
events = append(events, e)
}
events = append(events, stateEvents...)
return events, nil
}