Implement forgetting about rooms (#1572)

* Add basic storage methods

* Add internal api handler

* Add check for forgotten room

* Add /rooms/{roomID}/forget endpoint

* Add missing rsAPI method

* Remove unused parameters

* Add passing tests

Signed-off-by: Till Faelligen <tfaelligen@gmail.com>

* Add missing file

* Add postgres migration

* Add sqlite migration

* Use Forgetter to forget room

* Remove empty line

* Update HTTP status codes

It looks like the spec calls for these to be 400, rather than 403: https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-rooms-roomid-forget

Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
This commit is contained in:
S7evinK 2020-11-05 11:19:23 +01:00 committed by GitHub
parent 2ce2112ddb
commit eccd0d2c1b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 543 additions and 136 deletions

View file

@ -407,3 +407,47 @@ func checkMemberInRoom(ctx context.Context, rsAPI api.RoomserverInternalAPI, use
} }
return nil return nil
} }
func SendForget(
req *http.Request, device *userapi.Device,
roomID string, rsAPI roomserverAPI.RoomserverInternalAPI,
) util.JSONResponse {
ctx := req.Context()
logger := util.GetLogger(ctx).WithField("roomID", roomID).WithField("userID", device.UserID)
var membershipRes api.QueryMembershipForUserResponse
membershipReq := api.QueryMembershipForUserRequest{
RoomID: roomID,
UserID: device.UserID,
}
err := rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes)
if err != nil {
logger.WithError(err).Error("QueryMembershipForUser: could not query membership for user")
return jsonerror.InternalServerError()
}
if membershipRes.IsInRoom {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Forbidden("user is still a member of the room"),
}
}
if !membershipRes.HasBeenInRoom {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Forbidden("user did not belong to room"),
}
}
request := api.PerformForgetRequest{
RoomID: roomID,
UserID: device.UserID,
}
response := api.PerformForgetResponse{}
if err := rsAPI.PerformForget(ctx, &request, &response); err != nil {
logger.WithError(err).Error("PerformForget: unable to forget room")
return jsonerror.InternalServerError()
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}

View file

@ -709,6 +709,19 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/forget",
httputil.MakeAuthAPI("rooms_forget", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.rateLimit(req); r != nil {
return *r
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return SendForget(req, device, vars["roomID"], rsAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/devices", r0mux.Handle("/devices",
httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetDevicesByLocalpart(req, userAPI, device) return GetDevicesByLocalpart(req, userAPI, device)

View file

@ -84,6 +84,10 @@ type testRoomserverAPI struct {
queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse
} }
func (t *testRoomserverAPI) PerformForget(ctx context.Context, req *api.PerformForgetRequest, resp *api.PerformForgetResponse) error {
return nil
}
func (t *testRoomserverAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) {} func (t *testRoomserverAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) {}
func (t *testRoomserverAPI) InputRoomEvents( func (t *testRoomserverAPI) InputRoomEvents(

View file

@ -147,6 +147,9 @@ type RoomserverInternalAPI interface {
response *PerformBackfillResponse, response *PerformBackfillResponse,
) error ) error
// PerformForget forgets a rooms history for a specific user
PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error
// Asks for the default room version as preferred by the server. // Asks for the default room version as preferred by the server.
QueryRoomVersionCapabilities( QueryRoomVersionCapabilities(
ctx context.Context, ctx context.Context,

View file

@ -194,6 +194,16 @@ func (t *RoomserverInternalAPITrace) PerformBackfill(
return err return err
} }
func (t *RoomserverInternalAPITrace) PerformForget(
ctx context.Context,
req *PerformForgetRequest,
res *PerformForgetResponse,
) error {
err := t.Impl.PerformForget(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("PerformForget req=%+v res=%+v", js(req), js(res))
return err
}
func (t *RoomserverInternalAPITrace) QueryRoomVersionCapabilities( func (t *RoomserverInternalAPITrace) QueryRoomVersionCapabilities(
ctx context.Context, ctx context.Context,
req *QueryRoomVersionCapabilitiesRequest, req *QueryRoomVersionCapabilitiesRequest,

View file

@ -159,3 +159,11 @@ type PerformPublishResponse struct {
// If non-nil, the publish request failed. Contains more information why it failed. // If non-nil, the publish request failed. Contains more information why it failed.
Error *PerformError Error *PerformError
} }
// PerformForgetRequest is a request to PerformForget
type PerformForgetRequest struct {
RoomID string `json:"room_id"`
UserID string `json:"user_id"`
}
type PerformForgetResponse struct{}

View file

@ -140,7 +140,9 @@ type QueryMembershipForUserResponse struct {
// True if the user is in room. // True if the user is in room.
IsInRoom bool `json:"is_in_room"` IsInRoom bool `json:"is_in_room"`
// The current membership // The current membership
Membership string Membership string `json:"membership"`
// True if the user asked to forget this room.
IsRoomForgotten bool `json:"is_room_forgotten"`
} }
// QueryMembershipsForRoomRequest is a request to QueryMembershipsForRoom // QueryMembershipsForRoomRequest is a request to QueryMembershipsForRoom
@ -160,6 +162,8 @@ type QueryMembershipsForRoomResponse struct {
// True if the user has been in room before and has either stayed in it or // True if the user has been in room before and has either stayed in it or
// left it. // left it.
HasBeenInRoom bool `json:"has_been_in_room"` HasBeenInRoom bool `json:"has_been_in_room"`
// True if the user asked to forget this room.
IsRoomForgotten bool `json:"is_room_forgotten"`
} }
// QueryServerJoinedToRoomRequest is a request to QueryServerJoinedToRoom // QueryServerJoinedToRoomRequest is a request to QueryServerJoinedToRoom

View file

@ -26,6 +26,7 @@ type RoomserverInternalAPI struct {
*perform.Leaver *perform.Leaver
*perform.Publisher *perform.Publisher
*perform.Backfiller *perform.Backfiller
*perform.Forgetter
DB storage.Database DB storage.Database
Cfg *config.RoomServer Cfg *config.RoomServer
Producer sarama.SyncProducer Producer sarama.SyncProducer
@ -112,6 +113,9 @@ func (r *RoomserverInternalAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSen
// than trying random servers // than trying random servers
PreferServers: r.PerspectiveServerNames, PreferServers: r.PerspectiveServerNames,
} }
r.Forgetter = &perform.Forgetter{
DB: r.DB,
}
} }
func (r *RoomserverInternalAPI) PerformInvite( func (r *RoomserverInternalAPI) PerformInvite(
@ -143,3 +147,11 @@ func (r *RoomserverInternalAPI) PerformLeave(
} }
return r.WriteOutputEvents(req.RoomID, outputEvents) return r.WriteOutputEvents(req.RoomID, outputEvents)
} }
func (r *RoomserverInternalAPI) PerformForget(
ctx context.Context,
req *api.PerformForgetRequest,
resp *api.PerformForgetResponse,
) error {
return r.Forgetter.PerformForget(ctx, req, resp)
}

View file

@ -0,0 +1,35 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package perform
import (
"context"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/storage"
)
type Forgetter struct {
DB storage.Database
}
// PerformForget implements api.RoomServerQueryAPI
func (f *Forgetter) PerformForget(
ctx context.Context,
request *api.PerformForgetRequest,
response *api.PerformForgetResponse,
) error {
return f.DB.ForgetRoom(ctx, request.UserID, request.RoomID, true)
}

View file

@ -86,7 +86,7 @@ func (r *Inviter) PerformInvite(
var isAlreadyJoined bool var isAlreadyJoined bool
if info != nil { if info != nil {
_, isAlreadyJoined, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey()) _, isAlreadyJoined, _, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey())
if err != nil { if err != nil {
return nil, fmt.Errorf("r.DB.GetMembership: %w", err) return nil, fmt.Errorf("r.DB.GetMembership: %w", err)
} }

View file

@ -204,11 +204,13 @@ func (r *Queryer) QueryMembershipForUser(
return fmt.Errorf("QueryMembershipForUser: unknown room %s", request.RoomID) return fmt.Errorf("QueryMembershipForUser: unknown room %s", request.RoomID)
} }
membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID) membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID)
if err != nil { if err != nil {
return err return err
} }
response.IsRoomForgotten = isRoomforgotten
if membershipEventNID == 0 { if membershipEventNID == 0 {
response.HasBeenInRoom = false response.HasBeenInRoom = false
return nil return nil
@ -241,11 +243,13 @@ func (r *Queryer) QueryMembershipsForRoom(
return err return err
} }
membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender) membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender)
if err != nil { if err != nil {
return err return err
} }
response.IsRoomForgotten = isRoomforgotten
if membershipEventNID == 0 { if membershipEventNID == 0 {
response.HasBeenInRoom = false response.HasBeenInRoom = false
response.JoinEvents = nil response.JoinEvents = nil

View file

@ -31,6 +31,7 @@ const (
RoomserverPerformLeavePath = "/roomserver/performLeave" RoomserverPerformLeavePath = "/roomserver/performLeave"
RoomserverPerformBackfillPath = "/roomserver/performBackfill" RoomserverPerformBackfillPath = "/roomserver/performBackfill"
RoomserverPerformPublishPath = "/roomserver/performPublish" RoomserverPerformPublishPath = "/roomserver/performPublish"
RoomserverPerformForgetPath = "/roomserver/performForget"
// Query operations // Query operations
RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState" RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState"
@ -492,3 +493,12 @@ func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom(
apiURL := h.roomserverURL + RoomserverQueryServerBannedFromRoomPath apiURL := h.roomserverURL + RoomserverQueryServerBannedFromRoomPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
func (h *httpRoomserverInternalAPI) PerformForget(ctx context.Context, req *api.PerformForgetRequest, res *api.PerformForgetResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformForget")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformForgetPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}

View file

@ -251,6 +251,20 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(
RoomserverPerformForgetPath,
httputil.MakeInternalAPI("PerformForget", func(req *http.Request) util.JSONResponse {
var request api.PerformForgetRequest
var response api.PerformForgetResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.PerformForget(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverQueryRoomVersionCapabilitiesPath, RoomserverQueryRoomVersionCapabilitiesPath,
httputil.MakeInternalAPI("QueryRoomVersionCapabilities", func(req *http.Request) util.JSONResponse { httputil.MakeInternalAPI("QueryRoomVersionCapabilities", func(req *http.Request) util.JSONResponse {

View file

@ -126,7 +126,7 @@ type Database interface {
// in this room, along a boolean set to true if the user is still in this room, // in this room, along a boolean set to true if the user is still in this room,
// false if not. // false if not.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomForgotten bool, err error)
// Lookup the membership event numeric IDs for all user that are or have // Lookup the membership event numeric IDs for all user that are or have
// been members of a given room. Only lookup events of "join" membership if // been members of a given room. Only lookup events of "join" membership if
// joinOnly is set to true. // joinOnly is set to true.
@ -158,4 +158,6 @@ type Database interface {
GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
// GetKnownRooms returns a list of all rooms we know about. // GetKnownRooms returns a list of all rooms we know about.
GetKnownRooms(ctx context.Context) ([]string, error) GetKnownRooms(ctx context.Context) ([]string, error)
// ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room
ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error
} }

View file

@ -0,0 +1,47 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
)
func LoadFromGoose() {
goose.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn)
}
func LoadAddForgottenColumn(m *sqlutil.Migrations) {
m.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn)
}
func UpAddForgottenColumn(tx *sql.Tx) error {
_, err := tx.Exec(`ALTER TABLE roomserver_membership ADD COLUMN IF NOT EXISTS forgotten BOOLEAN NOT NULL DEFAULT false;`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownAddForgottenColumn(tx *sql.Tx) error {
_, err := tx.Exec(`ALTER TABLE roomserver_membership DROP COLUMN IF EXISTS forgotten;`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -60,13 +60,15 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
-- a federated one. This is an optimisation for resetting state on federated -- a federated one. This is an optimisation for resetting state on federated
-- room joins. -- room joins.
target_local BOOLEAN NOT NULL DEFAULT false, target_local BOOLEAN NOT NULL DEFAULT false,
forgotten BOOLEAN NOT NULL DEFAULT FALSE,
UNIQUE (room_nid, target_nid) UNIQUE (room_nid, target_nid)
); );
` `
var selectJoinedUsersSetForRoomsSQL = "" + var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " GROUP BY target_nid" " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
// Insert a row in to membership table so that it can be locked by the // Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE // SELECT FOR UPDATE
@ -76,37 +78,41 @@ const insertMembershipSQL = "" +
" ON CONFLICT DO NOTHING" " ON CONFLICT DO NOTHING"
const selectMembershipFromRoomAndTargetSQL = "" + const selectMembershipFromRoomAndTargetSQL = "" +
"SELECT membership_nid, event_nid FROM roomserver_membership" + "SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2" " WHERE room_nid = $1 AND target_nid = $2"
const selectMembershipsFromRoomAndMembershipSQL = "" + const selectMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2" " WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false"
const selectLocalMembershipsFromRoomAndMembershipSQL = "" + const selectLocalMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2" + " WHERE room_nid = $1 AND membership_nid = $2" +
" AND target_local = true" " AND target_local = true and forgotten = false"
const selectMembershipsFromRoomSQL = "" + const selectMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1" " WHERE room_nid = $1 and forgotten = false"
const selectLocalMembershipsFromRoomSQL = "" + const selectLocalMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1" + " WHERE room_nid = $1" +
" AND target_local = true" " AND target_local = true and forgotten = false"
const selectMembershipForUpdateSQL = "" + const selectMembershipForUpdateSQL = "" +
"SELECT membership_nid FROM roomserver_membership" + "SELECT membership_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2 FOR UPDATE" " WHERE room_nid = $1 AND target_nid = $2 FOR UPDATE"
const updateMembershipSQL = "" + const updateMembershipSQL = "" +
"UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4, event_nid = $5" + "UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4, event_nid = $5, forgotten = $6" +
" WHERE room_nid = $1 AND target_nid = $2"
const updateMembershipForgetRoom = "" +
"UPDATE roomserver_membership SET forgotten = $3" +
" WHERE room_nid = $1 AND target_nid = $2" " WHERE room_nid = $1 AND target_nid = $2"
const selectRoomsWithMembershipSQL = "" + const selectRoomsWithMembershipSQL = "" +
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2" "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is // selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
// joined to. Since this information is used to populate the user directory, we will // joined to. Since this information is used to populate the user directory, we will
@ -130,6 +136,7 @@ type membershipStatements struct {
selectRoomsWithMembershipStmt *sql.Stmt selectRoomsWithMembershipStmt *sql.Stmt
selectJoinedUsersSetForRoomsStmt *sql.Stmt selectJoinedUsersSetForRoomsStmt *sql.Stmt
selectKnownUsersStmt *sql.Stmt selectKnownUsersStmt *sql.Stmt
updateMembershipForgetRoomStmt *sql.Stmt
} }
func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
@ -151,9 +158,15 @@ func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
{&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL}, {&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL},
{&s.selectKnownUsersStmt, selectKnownUsersSQL}, {&s.selectKnownUsersStmt, selectKnownUsersSQL},
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
}.Prepare(db) }.Prepare(db)
} }
func (s *membershipStatements) execSchema(db *sql.DB) error {
_, err := db.Exec(membershipSchema)
return err
}
func (s *membershipStatements) InsertMembership( func (s *membershipStatements) InsertMembership(
ctx context.Context, ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
@ -177,10 +190,10 @@ func (s *membershipStatements) SelectMembershipForUpdate(
func (s *membershipStatements) SelectMembershipFromRoomAndTarget( func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
ctx context.Context, ctx context.Context,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventNID types.EventNID, membership tables.MembershipState, err error) { ) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
ctx, roomNID, targetUserNID, ctx, roomNID, targetUserNID,
).Scan(&membership, &eventNID) ).Scan(&membership, &eventNID, &forgotten)
return return
} }
@ -238,12 +251,11 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
func (s *membershipStatements) UpdateMembership( func (s *membershipStatements) UpdateMembership(
ctx context.Context, ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, forgotten bool,
eventNID types.EventNID,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext( _, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext(
ctx, roomNID, targetUserNID, senderUserNID, membership, eventNID, ctx, roomNID, targetUserNID, senderUserNID, membership, eventNID, forgotten,
) )
return err return err
} }
@ -305,3 +317,14 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type
} }
return result, rows.Err() return result, rows.Err()
} }
func (s *membershipStatements) UpdateForgetMembership(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
forget bool,
) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
ctx, roomNID, targetUserNID, forget,
)
return err
}

View file

@ -18,12 +18,13 @@ package postgres
import ( import (
"database/sql" "database/sql"
// Import the postgres database driver.
_ "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas"
// Import the postgres database driver.
_ "github.com/lib/pq"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
) )
@ -33,7 +34,6 @@ type Database struct {
} }
// Open a postgres database. // Open a postgres database.
// nolint: gocyclo
func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) {
var d Database var d Database
var db *sql.DB var db *sql.DB
@ -41,61 +41,82 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches)
if db, err = sqlutil.Open(dbProperties); err != nil { if db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err return nil, err
} }
// Create tables before executing migrations so we don't fail if the table is missing,
// and THEN prepare statements so we don't fail due to referencing new columns
ms := membershipStatements{}
if err := ms.execSchema(db); err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadAddForgottenColumn(m)
if err := m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
if err := d.prepare(db, cache); err != nil {
return nil, err
}
return &d, nil
}
// nolint: gocyclo
func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) (err error) {
eventStateKeys, err := NewPostgresEventStateKeysTable(db) eventStateKeys, err := NewPostgresEventStateKeysTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
eventTypes, err := NewPostgresEventTypesTable(db) eventTypes, err := NewPostgresEventTypesTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
eventJSON, err := NewPostgresEventJSONTable(db) eventJSON, err := NewPostgresEventJSONTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
events, err := NewPostgresEventsTable(db) events, err := NewPostgresEventsTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
rooms, err := NewPostgresRoomsTable(db) rooms, err := NewPostgresRoomsTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
transactions, err := NewPostgresTransactionsTable(db) transactions, err := NewPostgresTransactionsTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
stateBlock, err := NewPostgresStateBlockTable(db) stateBlock, err := NewPostgresStateBlockTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
stateSnapshot, err := NewPostgresStateSnapshotTable(db) stateSnapshot, err := NewPostgresStateSnapshotTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
roomAliases, err := NewPostgresRoomAliasesTable(db) roomAliases, err := NewPostgresRoomAliasesTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
prevEvents, err := NewPostgresPreviousEventsTable(db) prevEvents, err := NewPostgresPreviousEventsTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
invites, err := NewPostgresInvitesTable(db) invites, err := NewPostgresInvitesTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
membership, err := NewPostgresMembershipTable(db) membership, err := NewPostgresMembershipTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
published, err := NewPostgresPublishedTable(db) published, err := NewPostgresPublishedTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
redactions, err := NewPostgresRedactionsTable(db) redactions, err := NewPostgresRedactionsTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: db, DB: db,
@ -116,5 +137,5 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches)
PublishedTable: published, PublishedTable: published,
RedactionsTable: redactions, RedactionsTable: redactions,
} }
return &d, nil return nil
} }

View file

@ -101,9 +101,7 @@ func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er
return fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err) return fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err)
} }
if u.membership != tables.MembershipStateInvite { if u.membership != tables.MembershipStateInvite {
if err = u.d.MembershipTable.UpdateMembership( if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, false); err != nil {
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0,
); err != nil {
return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
} }
} }
@ -139,10 +137,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
} }
if u.membership != tables.MembershipStateJoin || isUpdate { if u.membership != tables.MembershipStateJoin || isUpdate {
if err = u.d.MembershipTable.UpdateMembership( if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateJoin, nIDs[eventID], false); err != nil {
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
tables.MembershipStateJoin, nIDs[eventID],
); err != nil {
return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
} }
} }
@ -176,10 +171,7 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s
} }
if u.membership != tables.MembershipStateLeaveOrBan { if u.membership != tables.MembershipStateLeaveOrBan {
if err = u.d.MembershipTable.UpdateMembership( if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateLeaveOrBan, nIDs[eventID], false); err != nil {
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
tables.MembershipStateLeaveOrBan, nIDs[eventID],
); err != nil {
return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
} }
} }

View file

@ -258,30 +258,28 @@ func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
}) })
} }
func (d *Database) GetMembership( func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomforgotten bool, err error) {
ctx context.Context, roomNID types.RoomNID, requestSenderUserID string,
) (membershipEventNID types.EventNID, stillInRoom bool, err error) {
var requestSenderUserNID types.EventStateKeyNID var requestSenderUserNID types.EventStateKeyNID
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
requestSenderUserNID, err = d.assignStateKeyNID(ctx, txn, requestSenderUserID) requestSenderUserNID, err = d.assignStateKeyNID(ctx, txn, requestSenderUserID)
return err return err
}) })
if err != nil { if err != nil {
return 0, false, fmt.Errorf("d.assignStateKeyNID: %w", err) return 0, false, false, fmt.Errorf("d.assignStateKeyNID: %w", err)
} }
senderMembershipEventNID, senderMembership, err := senderMembershipEventNID, senderMembership, isRoomforgotten, err :=
d.MembershipTable.SelectMembershipFromRoomAndTarget( d.MembershipTable.SelectMembershipFromRoomAndTarget(
ctx, roomNID, requestSenderUserNID, ctx, roomNID, requestSenderUserNID,
) )
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// The user has never been a member of that room // The user has never been a member of that room
return 0, false, nil return 0, false, false, nil
} else if err != nil { } else if err != nil {
return return
} }
return senderMembershipEventNID, senderMembership == tables.MembershipStateJoin, nil return senderMembershipEventNID, senderMembership == tables.MembershipStateJoin, isRoomforgotten, nil
} }
func (d *Database) GetMembershipEventNIDsForRoom( func (d *Database) GetMembershipEventNIDsForRoom(
@ -992,6 +990,25 @@ func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
return d.RoomsTable.SelectRoomIDs(ctx) return d.RoomsTable.SelectRoomIDs(ctx)
} }
// ForgetRoom sets a users room to forgotten
func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, []string{roomID})
if err != nil {
return err
}
if len(roomNIDs) > 1 {
return fmt.Errorf("expected one room, got %d", len(roomNIDs))
}
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID)
if err != nil {
return err
}
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.MembershipTable.UpdateForgetMembership(ctx, nil, roomNIDs[0], stateKeyNID, forget)
})
}
// FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops // FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops
// it should live in this package! // it should live in this package!

View file

@ -0,0 +1,82 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
)
func LoadFromGoose() {
goose.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn)
}
func LoadAddForgottenColumn(m *sqlutil.Migrations) {
m.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn)
}
func UpAddForgottenColumn(tx *sql.Tx) error {
_, err := tx.Exec(` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp;
CREATE TABLE IF NOT EXISTS roomserver_membership (
room_nid INTEGER NOT NULL,
target_nid INTEGER NOT NULL,
sender_nid INTEGER NOT NULL DEFAULT 0,
membership_nid INTEGER NOT NULL DEFAULT 1,
event_nid INTEGER NOT NULL DEFAULT 0,
target_local BOOLEAN NOT NULL DEFAULT false,
forgotten BOOLEAN NOT NULL DEFAULT false,
UNIQUE (room_nid, target_nid)
);
INSERT
INTO roomserver_membership (
room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local
) SELECT
room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local
FROM roomserver_membership_tmp
;
DROP TABLE roomserver_membership_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownAddForgottenColumn(tx *sql.Tx) error {
_, err := tx.Exec(` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp;
CREATE TABLE IF NOT EXISTS roomserver_membership (
room_nid INTEGER NOT NULL,
target_nid INTEGER NOT NULL,
sender_nid INTEGER NOT NULL DEFAULT 0,
membership_nid INTEGER NOT NULL DEFAULT 1,
event_nid INTEGER NOT NULL DEFAULT 0,
target_local BOOLEAN NOT NULL DEFAULT false,
UNIQUE (room_nid, target_nid)
);
INSERT
INTO roomserver_membership (
room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local
) SELECT
room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local
FROM roomserver_membership_tmp
;
DROP TABLE roomserver_membership_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -36,13 +36,15 @@ const membershipSchema = `
membership_nid INTEGER NOT NULL DEFAULT 1, membership_nid INTEGER NOT NULL DEFAULT 1,
event_nid INTEGER NOT NULL DEFAULT 0, event_nid INTEGER NOT NULL DEFAULT 0,
target_local BOOLEAN NOT NULL DEFAULT false, target_local BOOLEAN NOT NULL DEFAULT false,
forgotten BOOLEAN NOT NULL DEFAULT false,
UNIQUE (room_nid, target_nid) UNIQUE (room_nid, target_nid)
); );
` `
var selectJoinedUsersSetForRoomsSQL = "" + var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " GROUP BY target_nid" " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
// Insert a row in to membership table so that it can be locked by the // Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE // SELECT FOR UPDATE
@ -52,37 +54,41 @@ const insertMembershipSQL = "" +
" ON CONFLICT DO NOTHING" " ON CONFLICT DO NOTHING"
const selectMembershipFromRoomAndTargetSQL = "" + const selectMembershipFromRoomAndTargetSQL = "" +
"SELECT membership_nid, event_nid FROM roomserver_membership" + "SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2" " WHERE room_nid = $1 AND target_nid = $2"
const selectMembershipsFromRoomAndMembershipSQL = "" + const selectMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2" " WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false"
const selectLocalMembershipsFromRoomAndMembershipSQL = "" + const selectLocalMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2" + " WHERE room_nid = $1 AND membership_nid = $2" +
" AND target_local = true" " AND target_local = true and forgotten = false"
const selectMembershipsFromRoomSQL = "" + const selectMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1" " WHERE room_nid = $1 and forgotten = false"
const selectLocalMembershipsFromRoomSQL = "" + const selectLocalMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1" + " WHERE room_nid = $1" +
" AND target_local = true" " AND target_local = true and forgotten = false"
const selectMembershipForUpdateSQL = "" + const selectMembershipForUpdateSQL = "" +
"SELECT membership_nid FROM roomserver_membership" + "SELECT membership_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2" " WHERE room_nid = $1 AND target_nid = $2"
const updateMembershipSQL = "" + const updateMembershipSQL = "" +
"UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3" + "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3, forgotten = $4" +
" WHERE room_nid = $4 AND target_nid = $5" " WHERE room_nid = $5 AND target_nid = $6"
const updateMembershipForgetRoom = "" +
"UPDATE roomserver_membership SET forgotten = $1" +
" WHERE room_nid = $2 AND target_nid = $3"
const selectRoomsWithMembershipSQL = "" + const selectRoomsWithMembershipSQL = "" +
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2" "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is // selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
// joined to. Since this information is used to populate the user directory, we will // joined to. Since this information is used to populate the user directory, we will
@ -106,16 +112,13 @@ type membershipStatements struct {
selectRoomsWithMembershipStmt *sql.Stmt selectRoomsWithMembershipStmt *sql.Stmt
updateMembershipStmt *sql.Stmt updateMembershipStmt *sql.Stmt
selectKnownUsersStmt *sql.Stmt selectKnownUsersStmt *sql.Stmt
updateMembershipForgetRoomStmt *sql.Stmt
} }
func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
s := &membershipStatements{ s := &membershipStatements{
db: db, db: db,
} }
_, err := db.Exec(membershipSchema)
if err != nil {
return nil, err
}
return s, shared.StatementList{ return s, shared.StatementList{
{&s.insertMembershipStmt, insertMembershipSQL}, {&s.insertMembershipStmt, insertMembershipSQL},
@ -128,9 +131,15 @@ func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.updateMembershipStmt, updateMembershipSQL}, {&s.updateMembershipStmt, updateMembershipSQL},
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
{&s.selectKnownUsersStmt, selectKnownUsersSQL}, {&s.selectKnownUsersStmt, selectKnownUsersSQL},
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
}.Prepare(db) }.Prepare(db)
} }
func (s *membershipStatements) execSchema(db *sql.DB) error {
_, err := db.Exec(membershipSchema)
return err
}
func (s *membershipStatements) InsertMembership( func (s *membershipStatements) InsertMembership(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
@ -155,10 +164,10 @@ func (s *membershipStatements) SelectMembershipForUpdate(
func (s *membershipStatements) SelectMembershipFromRoomAndTarget( func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
ctx context.Context, ctx context.Context,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventNID types.EventNID, membership tables.MembershipState, err error) { ) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
ctx, roomNID, targetUserNID, ctx, roomNID, targetUserNID,
).Scan(&membership, &eventNID) ).Scan(&membership, &eventNID, &forgotten)
return return
} }
@ -216,13 +225,12 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
func (s *membershipStatements) UpdateMembership( func (s *membershipStatements) UpdateMembership(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, forgotten bool,
eventNID types.EventNID,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt) stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt)
_, err := stmt.ExecContext( _, err := stmt.ExecContext(
ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID, ctx, senderUserNID, membership, eventNID, forgotten, roomNID, targetUserNID,
) )
return err return err
} }
@ -285,3 +293,14 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type
} }
return result, rows.Err() return result, rows.Err()
} }
func (s *membershipStatements) UpdateForgetMembership(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
forget bool,
) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
ctx, forget, roomNID, targetUserNID,
)
return err
}

View file

@ -19,127 +19,138 @@ import (
"context" "context"
"database/sql" "database/sql"
_ "github.com/mattn/go-sqlite3"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
_ "github.com/mattn/go-sqlite3"
) )
// A Database is used to store room events and stream offsets. // A Database is used to store room events and stream offsets.
type Database struct { type Database struct {
shared.Database shared.Database
events tables.Events
eventJSON tables.EventJSON
eventTypes tables.EventTypes
eventStateKeys tables.EventStateKeys
rooms tables.Rooms
transactions tables.Transactions
prevEvents tables.PreviousEvents
invites tables.Invites
membership tables.Membership
db *sql.DB
writer sqlutil.Writer
} }
// Open a sqlite database. // Open a sqlite database.
// nolint: gocyclo
func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) {
var d Database var d Database
var db *sql.DB
var err error var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil { if db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err return nil, err
} }
d.writer = sqlutil.NewExclusiveWriter()
//d.db.Exec("PRAGMA journal_mode=WAL;") //db.Exec("PRAGMA journal_mode=WAL;")
//d.db.Exec("PRAGMA read_uncommitted = true;") //db.Exec("PRAGMA read_uncommitted = true;")
// FIXME: We are leaking connections somewhere. Setting this to 2 will eventually // FIXME: We are leaking connections somewhere. Setting this to 2 will eventually
// cause the roomserver to be unresponsive to new events because something will // cause the roomserver to be unresponsive to new events because something will
// acquire the global mutex and never unlock it because it is waiting for a connection // acquire the global mutex and never unlock it because it is waiting for a connection
// which it will never obtain. // which it will never obtain.
d.db.SetMaxOpenConns(20) db.SetMaxOpenConns(20)
d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db) // Create tables before executing migrations so we don't fail if the table is missing,
if err != nil { // and THEN prepare statements so we don't fail due to referencing new columns
ms := membershipStatements{}
if err := ms.execSchema(db); err != nil {
return nil, err return nil, err
} }
d.eventTypes, err = NewSqliteEventTypesTable(d.db) m := sqlutil.NewMigrations()
if err != nil { deltas.LoadAddForgottenColumn(m)
if err := m.RunDeltas(db, dbProperties); err != nil {
return nil, err return nil, err
} }
d.eventJSON, err = NewSqliteEventJSONTable(d.db) if err := d.prepare(db, cache); err != nil {
if err != nil {
return nil, err return nil, err
} }
d.events, err = NewSqliteEventsTable(d.db)
return &d, nil
}
// nolint: gocyclo
func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error {
var err error
eventStateKeys, err := NewSqliteEventStateKeysTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
d.rooms, err = NewSqliteRoomsTable(d.db) eventTypes, err := NewSqliteEventTypesTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
d.transactions, err = NewSqliteTransactionsTable(d.db) eventJSON, err := NewSqliteEventJSONTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
stateBlock, err := NewSqliteStateBlockTable(d.db) events, err := NewSqliteEventsTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
stateSnapshot, err := NewSqliteStateSnapshotTable(d.db) rooms, err := NewSqliteRoomsTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
d.prevEvents, err = NewSqlitePrevEventsTable(d.db) transactions, err := NewSqliteTransactionsTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
roomAliases, err := NewSqliteRoomAliasesTable(d.db) stateBlock, err := NewSqliteStateBlockTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
d.invites, err = NewSqliteInvitesTable(d.db) stateSnapshot, err := NewSqliteStateSnapshotTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
d.membership, err = NewSqliteMembershipTable(d.db) prevEvents, err := NewSqlitePrevEventsTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
published, err := NewSqlitePublishedTable(d.db) roomAliases, err := NewSqliteRoomAliasesTable(db)
if err != nil { if err != nil {
return nil, err return err
} }
redactions, err := NewSqliteRedactionsTable(d.db) invites, err := NewSqliteInvitesTable(db)
if err != nil { if err != nil {
return nil, err return err
}
membership, err := NewSqliteMembershipTable(db)
if err != nil {
return err
}
published, err := NewSqlitePublishedTable(db)
if err != nil {
return err
}
redactions, err := NewSqliteRedactionsTable(db)
if err != nil {
return err
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: db,
Cache: cache, Cache: cache,
Writer: d.writer, Writer: sqlutil.NewExclusiveWriter(),
EventsTable: d.events, EventsTable: events,
EventTypesTable: d.eventTypes, EventTypesTable: eventTypes,
EventStateKeysTable: d.eventStateKeys, EventStateKeysTable: eventStateKeys,
EventJSONTable: d.eventJSON, EventJSONTable: eventJSON,
RoomsTable: d.rooms, RoomsTable: rooms,
TransactionsTable: d.transactions, TransactionsTable: transactions,
StateBlockTable: stateBlock, StateBlockTable: stateBlock,
StateSnapshotTable: stateSnapshot, StateSnapshotTable: stateSnapshot,
PrevEventsTable: d.prevEvents, PrevEventsTable: prevEvents,
RoomAliasesTable: roomAliases, RoomAliasesTable: roomAliases,
InvitesTable: d.invites, InvitesTable: invites,
MembershipTable: d.membership, MembershipTable: membership,
PublishedTable: published, PublishedTable: published,
RedactionsTable: redactions, RedactionsTable: redactions,
GetLatestEventsForUpdateFn: d.GetLatestEventsForUpdate, GetLatestEventsForUpdateFn: d.GetLatestEventsForUpdate,
} }
return &d, nil return nil
} }
func (d *Database) SupportsConcurrentRoomInputs() bool { func (d *Database) SupportsConcurrentRoomInputs() bool {

View file

@ -123,15 +123,16 @@ const (
type Membership interface { type Membership interface {
InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error
SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error) SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error)
SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, error) SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error)
SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error
SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
// counts of how many rooms they are joined. // counts of how many rooms they are joined.
SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
} }
type Published interface { type Published interface {

View file

@ -59,6 +59,7 @@ const defaultMessagesLimit = 10
// OnIncomingMessagesRequest implements the /messages endpoint from the // OnIncomingMessagesRequest implements the /messages endpoint from the
// client-server API. // client-server API.
// See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages // See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages
// nolint:gocyclo
func OnIncomingMessagesRequest( func OnIncomingMessagesRequest(
req *http.Request, db storage.Database, roomID string, device *userapi.Device, req *http.Request, db storage.Database, roomID string, device *userapi.Device,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
@ -67,6 +68,19 @@ func OnIncomingMessagesRequest(
) util.JSONResponse { ) util.JSONResponse {
var err error var err error
// check if the user has already forgotten about this room
isForgotten, err := checkIsRoomForgotten(req.Context(), roomID, device.UserID, rsAPI)
if err != nil {
return jsonerror.InternalServerError()
}
if isForgotten {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("user already forgot about this room"),
}
}
// Extract parameters from the request's URL. // Extract parameters from the request's URL.
// Pagination tokens. // Pagination tokens.
var fromStream *types.StreamingToken var fromStream *types.StreamingToken
@ -182,6 +196,19 @@ func OnIncomingMessagesRequest(
} }
} }
func checkIsRoomForgotten(ctx context.Context, roomID, userID string, rsAPI api.RoomserverInternalAPI) (bool, error) {
req := api.QueryMembershipForUserRequest{
RoomID: roomID,
UserID: userID,
}
resp := api.QueryMembershipForUserResponse{}
if err := rsAPI.QueryMembershipForUser(ctx, &req, &resp); err != nil {
return false, err
}
return resp.IsRoomForgotten, nil
}
// retrieveEvents retrieves events from the local database for a request on // retrieveEvents retrieves events from the local database for a request on
// /messages. If there's not enough events to retrieve, it asks another // /messages. If there's not enough events to retrieve, it asks another
// homeserver in the room for older events. // homeserver in the room for older events.

View file

@ -486,3 +486,7 @@ Inbound federation rejects typing notifications from wrong remote
Should not be able to take over the room by pretending there is no PL event Should not be able to take over the room by pretending there is no PL event
Can get rooms/{roomId}/state for a departed room (SPEC-216) Can get rooms/{roomId}/state for a departed room (SPEC-216)
Users cannot set notifications powerlevel higher than their own Users cannot set notifications powerlevel higher than their own
Forgotten room messages cannot be paginated
Forgetting room does not show up in v2 /sync
Can forget room you've been kicked from
Can re-join room if re-invited