diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index ffd84e5bb..5d0717d7a 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -142,16 +142,12 @@ func (r *Queryer) QueryMissingAuthPrevEvents( response.RoomExists = !info.IsStub response.RoomVersion = info.RoomVersion - joined, err := r.DB.GetLocalServerInRoom(ctx, info.RoomNID) - if err != nil { - return fmt.Errorf("r.DB.GetLocalServerInRoom: %w", err) - } - response.RoomJoined = joined - - // If we're not joined to the room then there's no point in hitting - // the database further to work out which events we're missing. - if !joined { - return nil + if response.RoomExists { + joined, err := r.DB.GetLocalServerInRoom(ctx, info.RoomNID) + if err != nil { + return fmt.Errorf("r.DB.GetLocalServerInRoom: %w", err) + } + response.RoomJoined = joined } for _, authEventID := range request.AuthEventIDs { diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 888918161..9102f26a3 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -130,8 +130,7 @@ var selectKnownUsersSQL = "" + // is expensive. The presence of a single row from this query suggests we're still in the // room, no rows returned suggests we aren't. const selectLocalServerInRoomSQL = "" + - "SELECT room_nid FROM roomserver_membership WHERE target_local = true AND membership_nid = 3 AND room_nid = $1 " + - "FETCH FIRST 1 ROWS ONLY" + "SELECT room_nid FROM roomserver_membership WHERE target_local = true AND membership_nid = $1 AND room_nid = $2 LIMIT 1" type membershipStatements struct { insertMembershipStmt *sql.Stmt @@ -337,17 +336,14 @@ func (s *membershipStatements) UpdateForgetMembership( } func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { - rows, err := s.selectLocalServerInRoomStmt.QueryContext(ctx, roomNID) + var nid types.RoomNID + err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) if err != nil { if err == sql.ErrNoRows { return false, nil } return false, err } - defer internal.CloseAndLogIfError(ctx, rows, "SelectLocalServerInRoom: rows.close() failed") - found := false - for rows.Next() { - found = true - } - return found, rows.Err() + found := nid > 0 + return found, nil } diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index e4ad11892..82babe0d2 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -106,7 +106,7 @@ var selectKnownUsersSQL = "" + // is expensive. The presence of a single row from this query suggests we're still in the // room, no rows returned suggests we aren't. const selectLocalServerInRoomSQL = "" + - "SELECT room_nid FROM roomserver_membership WHERE target_local = 1 AND membership_nid = 3 AND room_nid = $1 LIMIT 1" + "SELECT room_nid FROM roomserver_membership WHERE target_local = 1 AND membership_nid = $1 AND room_nid = $2 LIMIT 1" type membershipStatements struct { db *sql.DB @@ -316,17 +316,14 @@ func (s *membershipStatements) UpdateForgetMembership( } func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { - rows, err := s.selectLocalServerInRoomStmt.QueryContext(ctx, roomNID) + var nid types.RoomNID + err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) if err != nil { if err == sql.ErrNoRows { return false, nil } return false, err } - defer internal.CloseAndLogIfError(ctx, rows, "SelectLocalServerInRoom: rows.close() failed") - found := false - for rows.Next() { - found = true - } - return found, rows.Err() + found := nid > 0 + return found, nil }