diff --git a/userapi/api/api.go b/userapi/api/api.go index c9f211b27..cf119a959 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -46,6 +46,7 @@ type UserInternalAPI interface { QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error QueryPresenceForUser(ctx context.Context, req *QueryPresenceForUserRequest, res *QueryPresenceForUserResponse) error QueryPresenceAfter(ctx context.Context, req *QueryPresenceAfterRequest, res *QueryPresenceAfterResponse) error + QueryMaxPresenceID(ctx context.Context, req *QueryMaxPresenceIDRequest, res *QueryMaxPresenceIDResponse) error } type PerformKeyBackupRequest struct { @@ -375,6 +376,14 @@ type QueryPresenceAfterResponse struct { Presences []QueryPresenceForUserResponse } +// QueryMaxPresenceIDRequest is the request for QueryMaxPresenceIDRequest +type QueryMaxPresenceIDRequest struct{} + +// QueryMaxPresenceIDResponse is the request for QueryMaxPresenceIDRequest +type QueryMaxPresenceIDResponse struct { + ID int64 +} + // Device represents a client's device (mobile, web, etc) type Device struct { ID string diff --git a/userapi/internal/api.go b/userapi/internal/api.go index b2c39451c..06d4ec383 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -620,3 +620,12 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB } res.Keys = result } + +func (a *UserInternalAPI) QueryMaxPresenceID(ctx context.Context, req *api.QueryMaxPresenceIDRequest, res *api.QueryMaxPresenceIDResponse) error { + id, err := a.PresenceDB.GetMaxPresenceID(ctx) + if err != nil { + return err + } + res.ID = id + return nil +} diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 869532b8f..124374b63 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -48,6 +48,7 @@ const ( QueryOpenIDTokenPath = "/userapi/queryOpenIDToken" QueryPresenceForUserPath = "/userapi/queryPresenceForUser" QueryPresenceAfterPath = "/userapi/queryPresenceAfter" + QueryMaxPresenceID = "/userapi/queryMaxPresenceID" QueryKeyBackupPath = "/userapi/queryKeyBackup" ) @@ -276,3 +277,11 @@ func (h *httpUserInternalAPI) QueryPresenceAfter(ctx context.Context, req *api.Q apiURL := h.apiURL + QueryPresenceAfterPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } + +func (h *httpUserInternalAPI) QueryMaxPresenceID(ctx context.Context, req *api.QueryMaxPresenceIDRequest, res *api.QueryMaxPresenceIDResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMaxPresenceID") + defer span.Finish() + + apiURL := h.apiURL + QueryMaxPresenceID + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 15861b7d1..c3cc17b16 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -273,4 +273,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QueryMaxPresenceID, + httputil.MakeInternalAPI("queryMaxPresenceID", func(req *http.Request) util.JSONResponse { + request := api.QueryMaxPresenceIDRequest{} + response := api.QueryMaxPresenceIDResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryMaxPresenceID(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/userapi/storage/presence/interface.go b/userapi/storage/presence/interface.go index c6d7839d5..2883a5191 100644 --- a/userapi/storage/presence/interface.go +++ b/userapi/storage/presence/interface.go @@ -38,4 +38,5 @@ type Database interface { ctx context.Context, pos int64, ) (presence []api.OutputPresenceData, err error) + GetMaxPresenceID(ctx context.Context) (pos int64, err error) } diff --git a/userapi/storage/presence/postgres/presence_table.go b/userapi/storage/presence/postgres/presence_table.go index 9da71b5ee..81a82cc48 100644 --- a/userapi/storage/presence/postgres/presence_table.go +++ b/userapi/storage/presence/postgres/presence_table.go @@ -20,7 +20,6 @@ import ( "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/internal/sqlutil" - types2 "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/userapi/types" ) @@ -49,7 +48,7 @@ const upsertPresenceSQL = "" + " (user_id, presence, status_msg, last_active_ts)" + " VALUES ($1, $2, $3, $4)" + " ON CONFLICT (user_id)" + - " DO UPDATE SET id = currval('presence_presence_id')," + + " DO UPDATE SET id = nextval('presence_presence_id')," + " presence = $2, status_msg = COALESCE($3, p.status_msg), last_active_ts = $4" + " RETURNING id" @@ -59,7 +58,7 @@ const selectPresenceForUserSQL = "" + " WHERE user_id = $1 LIMIT 1" const selectMaxPresenceSQL = "" + - "SELECT MAX(id) FROM presence_presences" + "SELECT COALESCE(MAX(id), 0) FROM presence_presences" const selectPresenceAfter = "" + " SELECT id, user_id, presence, status_msg, last_active_ts" + @@ -125,7 +124,7 @@ func (p *presenceStatements) GetPresenceForUser( return } -func (p *presenceStatements) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types2.StreamingToken, err error) { +func (p *presenceStatements) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos int64, err error) { stmt := sqlutil.TxStmt(txn, p.selectMaxPresenceStmt) err = stmt.QueryRowContext(ctx).Scan(&pos) return diff --git a/userapi/storage/presence/postgres/storage.go b/userapi/storage/presence/postgres/storage.go index 8bb30c9fa..9932b74d6 100644 --- a/userapi/storage/presence/postgres/storage.go +++ b/userapi/storage/presence/postgres/storage.go @@ -60,3 +60,7 @@ func (d *Database) GetPresenceForUser(ctx context.Context, userID string) (api.O func (d *Database) GetPresenceAfter(ctx context.Context, pos int64) (presence []api.OutputPresenceData, err error) { return d.presence.GetPresenceAfter(ctx, nil, pos) } + +func (d *Database) GetMaxPresenceID(ctx context.Context) (pos int64, err error) { + return d.presence.GetMaxPresenceID(ctx, nil) +} diff --git a/userapi/storage/presence/sqlite3/presence_table.go b/userapi/storage/presence/sqlite3/presence_table.go index cbc1379ad..66c0395fb 100644 --- a/userapi/storage/presence/sqlite3/presence_table.go +++ b/userapi/storage/presence/sqlite3/presence_table.go @@ -61,9 +61,13 @@ const selectPresenceAfter = "" + " FROM presence_presences" + " WHERE id > $1" +const selectMaxPresenceSQL = "" + + "SELECT COALESCE(MAX(id), 0) FROM presence_presences" + type presenceStatements struct { upsertPresenceStmt *sql.Stmt selectPresenceForUsersStmt *sql.Stmt + selectMaxPresenceStmt *sql.Stmt selectPresenceAfterStmt *sql.Stmt } @@ -82,6 +86,9 @@ func (p *presenceStatements) prepare(db *sql.DB) (err error) { if p.selectPresenceAfterStmt, err = db.Prepare(selectPresenceAfter); err != nil { return } + if p.selectMaxPresenceStmt, err = db.Prepare(selectMaxPresenceSQL); err != nil { + return + } return } @@ -130,3 +137,9 @@ func (p *presenceStatements) GetPresenceAfter( } return presences, rows.Err() } + +func (p *presenceStatements) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos int64, err error) { + stmt := sqlutil.TxStmt(txn, p.selectMaxPresenceStmt) + err = stmt.QueryRowContext(ctx).Scan(&pos) + return +} diff --git a/userapi/storage/presence/sqlite3/storage.go b/userapi/storage/presence/sqlite3/storage.go index 7ba8619d8..7a969c3c2 100644 --- a/userapi/storage/presence/sqlite3/storage.go +++ b/userapi/storage/presence/sqlite3/storage.go @@ -74,3 +74,7 @@ func (d *Database) GetPresenceForUser(ctx context.Context, userID string) (api.O func (d *Database) GetPresenceAfter(ctx context.Context, pos int64) (presence []api.OutputPresenceData, err error) { return d.presence.GetPresenceAfter(ctx, nil, pos) } + +func (d *Database) GetMaxPresenceID(ctx context.Context) (pos int64, err error) { + return d.presence.GetMaxPresenceID(ctx, nil) +}