Review comments

This commit is contained in:
Neil Alexander 2021-01-19 17:44:38 +00:00
parent 317226de98
commit e153f90e7a
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
4 changed files with 29 additions and 14 deletions

View file

@ -186,7 +186,7 @@ func (s *currentRoomStateStatements) SelectCurrentState(
}, },
stateFilter.Senders, stateFilter.NotSenders, stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes, stateFilter.Types, stateFilter.NotTypes,
stateFilter.Limit, "", stateFilter.Limit, FilterOrderNone,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, fmt.Errorf("s.prepareWithFilters: %w", err)

View file

@ -7,6 +7,14 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
) )
type FilterOrder int
const (
FilterOrderNone = iota
FilterOrderAsc
FilterOrderDesc
)
// prepareWithFilters returns a prepared statement with the // prepareWithFilters returns a prepared statement with the
// relevant filters included. It also includes an []interface{} // relevant filters included. It also includes an []interface{}
// list of all the relevant parameters to pass straight to // list of all the relevant parameters to pass straight to
@ -18,7 +26,7 @@ import (
func prepareWithFilters( func prepareWithFilters(
db *sql.DB, txn *sql.Tx, query string, params []interface{}, db *sql.DB, txn *sql.Tx, query string, params []interface{},
senders, notsenders, types, nottypes []string, senders, notsenders, types, nottypes []string,
limit int, order string, limit int, order FilterOrder,
) (*sql.Stmt, []interface{}, error) { ) (*sql.Stmt, []interface{}, error) {
offset := len(params) offset := len(params)
if count := len(senders); count > 0 { if count := len(senders); count > 0 {
@ -45,8 +53,11 @@ func prepareWithFilters(
params, offset = append(params, v), offset+1 params, offset = append(params, v), offset+1
} }
} }
if order != "" { switch order {
query += " ORDER BY id " + order case FilterOrderAsc:
query += " ORDER BY id ASC"
case FilterOrderDesc:
query += " ORDER BY id DESC"
} }
query += fmt.Sprintf(" LIMIT $%d", offset+1) query += fmt.Sprintf(" LIMIT $%d", offset+1)
params = append(params, limit) params = append(params, limit)

View file

@ -150,7 +150,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
}, },
stateFilter.Senders, stateFilter.NotSenders, stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes, stateFilter.Types, stateFilter.NotTypes,
stateFilter.Limit, "ASC", stateFilter.Limit, FilterOrderAsc,
) )
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err)
@ -273,11 +273,14 @@ func (s *outputRoomEventsStatements) InsertEvent(
if len(addState) > 0 { if len(addState) > 0 {
addStateJSON, err = json.Marshal(addState) addStateJSON, err = json.Marshal(addState)
} }
if err != nil {
return 0, fmt.Errorf("json.Marshal(addState): %w", err)
}
if len(removeState) > 0 { if len(removeState) > 0 {
removeStateJSON, err = json.Marshal(removeState) removeStateJSON, err = json.Marshal(removeState)
} }
if err != nil { if err != nil {
return 0, err return 0, fmt.Errorf("json.Marshal(removeState): %w", err)
} }
streamPos, err := s.streamIDStatements.nextPDUID(ctx, txn) streamPos, err := s.streamIDStatements.nextPDUID(ctx, txn)
@ -323,7 +326,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
}, },
eventFilter.Senders, eventFilter.NotSenders, eventFilter.Senders, eventFilter.NotSenders,
eventFilter.Types, eventFilter.NotTypes, eventFilter.Types, eventFilter.NotTypes,
eventFilter.Limit+1, "DESC", eventFilter.Limit+1, FilterOrderDesc,
) )
if err != nil { if err != nil {
return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err)
@ -371,7 +374,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
}, },
eventFilter.Senders, eventFilter.NotSenders, eventFilter.Senders, eventFilter.NotSenders,
eventFilter.Types, eventFilter.NotTypes, eventFilter.Types, eventFilter.NotTypes,
eventFilter.Limit, "ASC", eventFilter.Limit, FilterOrderAsc,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, fmt.Errorf("s.prepareWithFilters: %w", err)

View file

@ -49,21 +49,22 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
filterQuery := req.URL.Query().Get("filter") filterQuery := req.URL.Query().Get("filter")
if filterQuery != "" { if filterQuery != "" {
if filterQuery[0] == '{' { if filterQuery[0] == '{' {
// attempt to parse the timeline limit at least // Parse the filter from the query string
if err := json.Unmarshal([]byte(filterQuery), &filter); err != nil { if err := json.Unmarshal([]byte(filterQuery), &filter); err != nil {
return nil, fmt.Errorf("json.Unmarshal: %w", err) return nil, fmt.Errorf("json.Unmarshal: %w", err)
} }
} else { } else {
// attempt to load the filter ID // Try to load the filter from the database
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return nil, err return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
} }
if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err == nil { if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil {
filter = *f util.GetLogger(req.Context()).WithError(err).Error("syncDB.GetFilter failed")
return nil, fmt.Errorf("syncDB.GetFilter: %w", err)
} else { } else {
panic(err) filter = *f
} }
} }
} }