diff --git a/clientapi/auth/storage/accounts/interface.go b/clientapi/auth/storage/accounts/interface.go index 9f6e3e1ea..a5052b047 100644 --- a/clientapi/auth/storage/accounts/interface.go +++ b/clientapi/auth/storage/accounts/interface.go @@ -33,6 +33,7 @@ type Database interface { CreateGuestAccount(ctx context.Context) (*authtypes.Account, error) UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error) + GetRoomIDsByLocalPart(ctx context.Context, localpart string) ([]string, error) GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error) SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error) diff --git a/clientapi/auth/storage/accounts/postgres/membership_table.go b/clientapi/auth/storage/accounts/postgres/membership_table.go index 426c2d6ac..ceabf2615 100644 --- a/clientapi/auth/storage/accounts/postgres/membership_table.go +++ b/clientapi/auth/storage/accounts/postgres/membership_table.go @@ -51,6 +51,9 @@ const selectMembershipsByLocalpartSQL = "" + const selectMembershipInRoomByLocalpartSQL = "" + "SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2" +const selectRoomIDsByLocalPartSQL = "" + + "SELECT room_id FROM account_memberships WHERE localpart = $1" + const deleteMembershipsByEventIDsSQL = "" + "DELETE FROM account_memberships WHERE event_id = ANY($1)" @@ -59,6 +62,7 @@ type membershipStatements struct { insertMembershipStmt *sql.Stmt selectMembershipInRoomByLocalpartStmt *sql.Stmt selectMembershipsByLocalpartStmt *sql.Stmt + selectRoomIDsByLocalPartStmt *sql.Stmt } func (s *membershipStatements) prepare(db *sql.DB) (err error) { @@ -78,6 +82,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil { return } + if s.selectRoomIDsByLocalPartStmt, err = db.Prepare(selectRoomIDsByLocalPartSQL); err != nil { + return + } return } @@ -129,3 +136,23 @@ func (s *membershipStatements) selectMembershipsByLocalpart( } return memberships, rows.Err() } + +func (s *membershipStatements) selectRoomIDsByLocalPart( + ctx context.Context, localPart string, +) ([]string, error) { + stmt := s.selectRoomIDsByLocalPartStmt + rows, err := stmt.QueryContext(ctx, localPart) + if err != nil { + return nil, err + } + roomIDs := []string{} + defer rows.Close() // nolint: errcheck + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, rows.Err() +} diff --git a/clientapi/auth/storage/accounts/postgres/storage.go b/clientapi/auth/storage/accounts/postgres/storage.go index 8115dca43..4a0a2060b 100644 --- a/clientapi/auth/storage/accounts/postgres/storage.go +++ b/clientapi/auth/storage/accounts/postgres/storage.go @@ -234,6 +234,16 @@ func (d *Database) GetMembershipInRoomByLocalpart( return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID) } +// GetRoomIDsByLocalPart returns an array containing the room ids of all +// the rooms a user matching a given localpart is a member of +// If no membership match the given localpart, returns an empty array +// If there was an issue during the retrieval, returns the SQL error +func (d *Database) GetRoomIDsByLocalPart( + ctx context.Context, localpart string, +) ([]string, error) { + return d.memberships.selectRoomIDsByLocalPart(ctx, localpart) +} + // GetMembershipsByLocalpart returns an array containing the memberships for all // the rooms a user matching a given localpart is a member of // If no membership match the given localpart, returns an empty array diff --git a/clientapi/auth/storage/accounts/sqlite3/membership_table.go b/clientapi/auth/storage/accounts/sqlite3/membership_table.go index 38f21b7f3..d2f7c07e9 100644 --- a/clientapi/auth/storage/accounts/sqlite3/membership_table.go +++ b/clientapi/auth/storage/accounts/sqlite3/membership_table.go @@ -51,6 +51,9 @@ const selectMembershipsByLocalpartSQL = "" + const selectMembershipInRoomByLocalpartSQL = "" + "SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2" +const selectRoomIDsByLocalPartSQL = "" + + "SELECT room_id FROM account_memberships WHERE localpart = $1" + const deleteMembershipsByEventIDsSQL = "" + "DELETE FROM account_memberships WHERE event_id IN ($1)" @@ -58,6 +61,7 @@ type membershipStatements struct { insertMembershipStmt *sql.Stmt selectMembershipInRoomByLocalpartStmt *sql.Stmt selectMembershipsByLocalpartStmt *sql.Stmt + selectRoomIDsByLocalPartStmt *sql.Stmt } func (s *membershipStatements) prepare(db *sql.DB) (err error) { @@ -74,6 +78,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil { return } + if s.selectRoomIDsByLocalPartStmt, err = db.Prepare(selectRoomIDsByLocalPartSQL); err != nil { + return + } return } @@ -130,3 +137,22 @@ func (s *membershipStatements) selectMembershipsByLocalpart( return } +func (s *membershipStatements) selectRoomIDsByLocalPart( + ctx context.Context, localPart string, +) ([]string, error) { + stmt := s.selectRoomIDsByLocalPartStmt + rows, err := stmt.QueryContext(ctx, localPart) + if err != nil { + return nil, err + } + roomIDs := []string{} + defer rows.Close() // nolint: errcheck + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, rows.Err() +} diff --git a/clientapi/auth/storage/accounts/sqlite3/storage.go b/clientapi/auth/storage/accounts/sqlite3/storage.go index 9124640c6..bfb7b4ea1 100644 --- a/clientapi/auth/storage/accounts/sqlite3/storage.go +++ b/clientapi/auth/storage/accounts/sqlite3/storage.go @@ -253,6 +253,16 @@ func (d *Database) GetMembershipsByLocalpart( return d.memberships.selectMembershipsByLocalpart(ctx, localpart) } +// GetRoomIDsByLocalPart returns an array containing the room ids of all +// the rooms a user matching a given localpart is a member of +// If no membership match the given localpart, returns an empty array +// If there was an issue during the retrieval, returns the SQL error +func (d *Database) GetRoomIDsByLocalPart( + ctx context.Context, localpart string, +) ([]string, error) { + return d.memberships.selectRoomIDsByLocalPart(ctx, localpart) +} + // newMembership saves a new membership in the database. // If the event isn't a valid m.room.member event with type `join`, does nothing. // If an error occurred, returns the SQL error diff --git a/clientapi/routing/memberships.go b/clientapi/routing/memberships.go index a6899eeb8..0b846e5e3 100644 --- a/clientapi/routing/memberships.go +++ b/clientapi/routing/memberships.go @@ -17,6 +17,8 @@ package routing import ( "net/http" + "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/common/config" @@ -25,10 +27,14 @@ import ( "github.com/matrix-org/util" ) -type response struct { +type getMembershipResponse struct { Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` } +type getJoinedRoomsResponse struct { + JoinedRooms []string `json:"joined_rooms"` +} + // GetMemberships implements GET /rooms/{roomId}/members func GetMemberships( req *http.Request, device *authtypes.Device, roomID string, joinedOnly bool, @@ -55,6 +61,27 @@ func GetMemberships( return util.JSONResponse{ Code: http.StatusOK, - JSON: response{queryRes.JoinEvents}, + JSON: getMembershipResponse{queryRes.JoinEvents}, + } +} + +func GetJoinedRooms( + req *http.Request, + device *authtypes.Device, + accountsDB accounts.Database, +) util.JSONResponse { + localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") + return jsonerror.InternalServerError() + } + joinedRooms, err := accountsDB.GetRoomIDsByLocalPart(req.Context(), localpart) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("accountsDB.GetRoomIDsByLocalPart failed") + return jsonerror.InternalServerError() + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: getJoinedRoomsResponse{joinedRooms}, } } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index f0841b796..3caf75efc 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -105,6 +105,12 @@ func Setup( ) }), ).Methods(http.MethodPost, http.MethodOptions) + r0mux.Handle("/joined_rooms", + common.MakeAuthAPI("joined_rooms", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + return GetJoinedRooms(req, device, accountDB) + }), + ).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/rooms/{roomID}/{membership:(?:join|kick|ban|unban|leave|invite)}", common.MakeAuthAPI("membership", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { vars, err := common.URLDecodeMapValues(mux.Vars(req)) @@ -355,6 +361,12 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/account/deactivate", + common.MakeAuthAPI("account_deactivate", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + return GetAssociated3PIDs(req, accountDB, device) + }), + ).Methods(http.MethodPost, http.MethodOptions) + r0mux.Handle("/account/3pid", common.MakeAuthAPI("account_3pid", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { return CheckAndSave3PIDAssociation(req, accountDB, device, cfg) diff --git a/go.mod b/go.mod index 0728a03cc..cd8e4903f 100644 --- a/go.mod +++ b/go.mod @@ -14,15 +14,19 @@ require ( github.com/matrix-org/naffka v0.0.0-20200127221512-0716baaabaf1 github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 github.com/mattn/go-sqlite3 v2.0.2+incompatible + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.1 // indirect github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5 github.com/opentracing/opentracing-go v1.1.0 github.com/pkg/errors v0.8.1 github.com/prometheus/client_golang v1.4.1 + github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 // indirect github.com/sirupsen/logrus v1.4.2 github.com/tidwall/gjson v1.6.0 // indirect github.com/tidwall/pretty v1.0.1 // indirect github.com/uber/jaeger-client-go v2.22.1+incompatible github.com/uber/jaeger-lib v2.2.0+incompatible + go.uber.org/atomic v1.6.0 // indirect golang.org/x/crypto v0.0.0-20200115085410-6d4e4cb37c7d gopkg.in/Shopify/sarama.v1 v1.20.1 gopkg.in/h2non/bimg.v1 v1.0.18 diff --git a/go.sum b/go.sum index af0876489..4b83e0f13 100644 --- a/go.sum +++ b/go.sum @@ -222,6 +222,8 @@ github.com/prometheus/procfs v0.0.8 h1:+fpWZdT24pJBiqJdAwYBjPSk+5YmQzYNPYzQsdzLk github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563 h1:dY6ETXrvDG7Sa4vE8ZQG4yqWg6UnOcbqTAahkV813vQ= github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= +github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 h1:MkV+77GLUNo5oJ0jf870itWm3D0Sjh7+Za9gazKc5LQ= +github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.3.0 h1:hI/7Q+DtNZ2kINb6qt/lS+IyXnHQe9e90POfeewL/ME= github.com/sirupsen/logrus v1.3.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= @@ -263,6 +265,8 @@ github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7V go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.uber.org/atomic v1.3.0 h1:vs7fgriifsPbGdK3bNuMWapNn3qnZhCRXc19NRdq010= go.uber.org/atomic v1.3.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180723164146-c126467f60eb/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -279,6 +283,7 @@ golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -297,6 +302,7 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6 h1:bjcUS9ztw9kFmmIxJInhon/0Is3p+EHBKNgquIzo1OI= golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -321,6 +327,8 @@ golang.org/x/tools v0.0.0-20181130052023-1c3d964395ce/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=