From 4406c85c98e685348f2858b89485ad539dff8405 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 15 Jun 2023 14:48:21 +0200 Subject: [PATCH] Merge main, fix issues --- mediaapi/routing/download.go | 5 +++-- roomserver/internal/query/query.go | 2 +- roomserver/storage/postgres/user_room_keys_table.go | 3 ++- roomserver/storage/shared/storage_test.go | 2 +- roomserver/storage/sqlite3/user_room_keys_table.go | 3 ++- roomserver/storage/tables/user_room_keys_table_test.go | 5 +++-- syncapi/routing/search_test.go | 2 +- syncapi/storage/postgres/memberships_table.go | 2 +- syncapi/storage/sqlite3/memberships_table.go | 2 +- syncapi/storage/tables/memberships_test.go | 2 ++ 10 files changed, 17 insertions(+), 11 deletions(-) diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go index e9f161a3c..8fb1b6534 100644 --- a/mediaapi/routing/download.go +++ b/mediaapi/routing/download.go @@ -341,6 +341,7 @@ func (r *downloadRequest) addDownloadFilenameToHeaders( } if len(filename) == 0 { + w.Header().Set("Content-Disposition", "attachment") return nil } @@ -376,13 +377,13 @@ func (r *downloadRequest) addDownloadFilenameToHeaders( // that would otherwise be parsed as a control character in the // Content-Disposition header w.Header().Set("Content-Disposition", fmt.Sprintf( - `inline; filename=%s%s%s`, + `attachment; filename=%s%s%s`, quote, unescaped, quote, )) } else { // For UTF-8 filenames, we quote always, as that's the standard w.Header().Set("Content-Disposition", fmt.Sprintf( - `inline; filename*=utf-8''%s`, + `attachment; filename*=utf-8''%s`, url.QueryEscape(unescaped), )) } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 1b5637266..918619e5e 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -1029,7 +1029,7 @@ func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, } if userKeys, ok := result[roomID]; ok { - if userID, ok := userKeys[string(bytes)]; ok { + if userID, ok := userKeys[string(senderID)]; ok { return spec.NewUserID(userID, true) } } diff --git a/roomserver/storage/postgres/user_room_keys_table.go b/roomserver/storage/postgres/user_room_keys_table.go index dbb4af34a..202b0abc1 100644 --- a/roomserver/storage/postgres/user_room_keys_table.go +++ b/roomserver/storage/postgres/user_room_keys_table.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib/spec" ) const userRoomKeysSchema = ` @@ -145,7 +146,7 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq if err = rows.Scan(&userRoomKeyPair.EventStateKeyNID, &userRoomKeyPair.RoomNID, &publicKey); err != nil { return nil, err } - result[string(publicKey)] = userRoomKeyPair + result[spec.Base64Bytes(publicKey).Encode()] = userRoomKeyPair } return result, rows.Err() } diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go index c7b915c7d..612e4ef06 100644 --- a/roomserver/storage/shared/storage_test.go +++ b/roomserver/storage/shared/storage_test.go @@ -183,7 +183,7 @@ func TestUserRoomKeys(t *testing.T) { assert.NoError(t, err) wantKeys := map[spec.RoomID]map[string]string{ *roomID: { - string(key.Public().(ed25519.PublicKey)): userID.String(), + spec.Base64Bytes(key.Public().(ed25519.PublicKey)).Encode(): userID.String(), }, } assert.Equal(t, wantKeys, userIDs) diff --git a/roomserver/storage/sqlite3/user_room_keys_table.go b/roomserver/storage/sqlite3/user_room_keys_table.go index 84c8b54ec..d58b8ac3f 100644 --- a/roomserver/storage/sqlite3/user_room_keys_table.go +++ b/roomserver/storage/sqlite3/user_room_keys_table.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib/spec" ) const userRoomKeysSchema = ` @@ -159,7 +160,7 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq if err = rows.Scan(&userRoomKeyPair.EventStateKeyNID, &userRoomKeyPair.RoomNID, &publicKey); err != nil { return nil, err } - result[string(publicKey)] = userRoomKeyPair + result[spec.Base64Bytes(publicKey).Encode()] = userRoomKeyPair } return result, rows.Err() } diff --git a/roomserver/storage/tables/user_room_keys_table_test.go b/roomserver/storage/tables/user_room_keys_table_test.go index 8802a3c6e..2809771b4 100644 --- a/roomserver/storage/tables/user_room_keys_table_test.go +++ b/roomserver/storage/tables/user_room_keys_table_test.go @@ -13,6 +13,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" ed255192 "golang.org/x/crypto/ed25519" ) @@ -101,8 +102,8 @@ func TestUserRoomKeysTable(t *testing.T) { assert.NotNil(t, gotKeys) wantKeys := map[string]types.UserRoomKeyPair{ - string(key2.Public().(ed25519.PublicKey)): {RoomNID: roomNID, EventStateKeyNID: userNID}, - string(key3.Public().(ed25519.PublicKey)): {RoomNID: roomNID, EventStateKeyNID: userNID2}, + string(spec.Base64Bytes(key2.Public().(ed25519.PublicKey)).Encode()): {RoomNID: roomNID, EventStateKeyNID: userNID}, + string(spec.Base64Bytes(key3.Public().(ed25519.PublicKey)).Encode()): {RoomNID: roomNID, EventStateKeyNID: userNID2}, } assert.Equal(t, wantKeys, gotKeys) diff --git a/syncapi/routing/search_test.go b/syncapi/routing/search_test.go index 939650a95..905a9a1ac 100644 --- a/syncapi/routing/search_test.go +++ b/syncapi/routing/search_test.go @@ -230,7 +230,7 @@ func TestSearch(t *testing.T) { stateEvents = append(stateEvents, x) stateEventIDs = append(stateEventIDs, x.EventID()) } - + x.StateKeyResolved = x.StateKey() sp, err = db.WriteEvent(processCtx.Context(), x, stateEvents, stateEventIDs, nil, nil, false, gomatrixserverlib.HistoryVisibilityShared) assert.NoError(t, err) if x.Type() != "m.room.message" { diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index 6bf9f61a1..09b47432b 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -109,7 +109,7 @@ func (s *membershipsStatements) UpsertMembership( _, err = sqlutil.TxStmt(txn, s.upsertMembershipStmt).ExecContext( ctx, event.RoomID(), - event.UserID.String(), + event.StateKeyResolved, membership, event.EventID(), streamPos, diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index 6218ea660..a9e880d2a 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -112,7 +112,7 @@ func (s *membershipsStatements) UpsertMembership( _, err = sqlutil.TxStmt(txn, s.upsertMembershipStmt).ExecContext( ctx, event.RoomID(), - event.UserID.String(), + event.StateKeyResolved, membership, event.EventID(), streamPos, diff --git a/syncapi/storage/tables/memberships_test.go b/syncapi/storage/tables/memberships_test.go index 4afa2ac5b..a421a9772 100644 --- a/syncapi/storage/tables/memberships_test.go +++ b/syncapi/storage/tables/memberships_test.go @@ -80,6 +80,7 @@ func TestMembershipsTable(t *testing.T) { defer cancel() for _, ev := range userEvents { + ev.StateKeyResolved = ev.StateKey() if err := table.UpsertMembership(ctx, nil, ev, types.StreamPosition(ev.Depth()), 1); err != nil { t.Fatalf("failed to upsert membership: %s", err) } @@ -134,6 +135,7 @@ func testUpsert(t *testing.T, ctx context.Context, table tables.Memberships, mem ev := room.CreateAndInsert(t, user, spec.MRoomMember, map[string]interface{}{ "membership": spec.Join, }, test.WithStateKey(user.ID)) + ev.StateKeyResolved = ev.StateKey() // Insert the same event again, but with different positions, which should get updated if err = table.UpsertMembership(ctx, nil, ev, 2, 2); err != nil { t.Fatalf("failed to upsert membership: %s", err)