diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 5d81547eb..55ed27a41 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -186,7 +186,7 @@ func (s *currentRoomStateStatements) SelectCurrentState( }, stateFilter.Senders, stateFilter.NotSenders, stateFilter.Types, stateFilter.NotTypes, - stateFilter.Limit, "", + stateFilter.Limit, FilterOrderNone, ) if err != nil { return nil, fmt.Errorf("s.prepareWithFilters: %w", err) diff --git a/syncapi/storage/sqlite3/filtering.go b/syncapi/storage/sqlite3/filtering.go index 0f6bd681f..0faf5297a 100644 --- a/syncapi/storage/sqlite3/filtering.go +++ b/syncapi/storage/sqlite3/filtering.go @@ -7,6 +7,14 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" ) +type FilterOrder int + +const ( + FilterOrderNone = iota + FilterOrderAsc + FilterOrderDesc +) + // prepareWithFilters returns a prepared statement with the // relevant filters included. It also includes an []interface{} // list of all the relevant parameters to pass straight to @@ -18,7 +26,7 @@ import ( func prepareWithFilters( db *sql.DB, txn *sql.Tx, query string, params []interface{}, senders, notsenders, types, nottypes []string, - limit int, order string, + limit int, order FilterOrder, ) (*sql.Stmt, []interface{}, error) { offset := len(params) if count := len(senders); count > 0 { @@ -45,8 +53,11 @@ func prepareWithFilters( params, offset = append(params, v), offset+1 } } - if order != "" { - query += " ORDER BY id " + order + switch order { + case FilterOrderAsc: + query += " ORDER BY id ASC" + case FilterOrderDesc: + query += " ORDER BY id DESC" } query += fmt.Sprintf(" LIMIT $%d", offset+1) params = append(params, limit) diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 4debc53f6..019aba8b3 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -150,7 +150,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange( }, stateFilter.Senders, stateFilter.NotSenders, stateFilter.Types, stateFilter.NotTypes, - stateFilter.Limit, "ASC", + stateFilter.Limit, FilterOrderAsc, ) if err != nil { return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err) @@ -273,11 +273,14 @@ func (s *outputRoomEventsStatements) InsertEvent( if len(addState) > 0 { addStateJSON, err = json.Marshal(addState) } + if err != nil { + return 0, fmt.Errorf("json.Marshal(addState): %w", err) + } if len(removeState) > 0 { removeStateJSON, err = json.Marshal(removeState) } if err != nil { - return 0, err + return 0, fmt.Errorf("json.Marshal(removeState): %w", err) } streamPos, err := s.streamIDStatements.nextPDUID(ctx, txn) @@ -323,7 +326,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( }, eventFilter.Senders, eventFilter.NotSenders, eventFilter.Types, eventFilter.NotTypes, - eventFilter.Limit+1, "DESC", + eventFilter.Limit+1, FilterOrderDesc, ) if err != nil { return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err) @@ -371,7 +374,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( }, eventFilter.Senders, eventFilter.NotSenders, eventFilter.Types, eventFilter.NotTypes, - eventFilter.Limit, "ASC", + eventFilter.Limit, FilterOrderAsc, ) if err != nil { return nil, fmt.Errorf("s.prepareWithFilters: %w", err) diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index 58eb531c0..09a62e3dd 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -49,21 +49,22 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat filterQuery := req.URL.Query().Get("filter") if filterQuery != "" { if filterQuery[0] == '{' { - // attempt to parse the timeline limit at least + // Parse the filter from the query string if err := json.Unmarshal([]byte(filterQuery), &filter); err != nil { return nil, fmt.Errorf("json.Unmarshal: %w", err) } } else { - // attempt to load the filter ID + // Try to load the filter from the database localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return nil, err + return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } - if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err == nil { - filter = *f + if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("syncDB.GetFilter failed") + return nil, fmt.Errorf("syncDB.GetFilter: %w", err) } else { - panic(err) + filter = *f } } }