Implement /joined_rooms (#911)

* Implemented /joined_rooms

* Removed account endpoint added by mistake

* trigger ci
This commit is contained in:
Prateek Sachan 2020-03-19 15:55:36 +05:30 committed by GitHub
parent ec38783192
commit dc06c69887
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 112 additions and 2 deletions

View file

@ -33,6 +33,7 @@ type Database interface {
CreateGuestAccount(ctx context.Context) (*authtypes.Account, error) CreateGuestAccount(ctx context.Context) (*authtypes.Account, error)
UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error
GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error) GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error)
GetRoomIDsByLocalPart(ctx context.Context, localpart string) ([]string, error)
GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error) GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error)
SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error
GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error) GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error)

View file

@ -53,6 +53,9 @@ const selectMembershipsByLocalpartSQL = "" +
const selectMembershipInRoomByLocalpartSQL = "" + const selectMembershipInRoomByLocalpartSQL = "" +
"SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2" "SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2"
const selectRoomIDsByLocalPartSQL = "" +
"SELECT room_id FROM account_memberships WHERE localpart = $1"
const deleteMembershipsByEventIDsSQL = "" + const deleteMembershipsByEventIDsSQL = "" +
"DELETE FROM account_memberships WHERE event_id = ANY($1)" "DELETE FROM account_memberships WHERE event_id = ANY($1)"
@ -61,6 +64,7 @@ type membershipStatements struct {
insertMembershipStmt *sql.Stmt insertMembershipStmt *sql.Stmt
selectMembershipInRoomByLocalpartStmt *sql.Stmt selectMembershipInRoomByLocalpartStmt *sql.Stmt
selectMembershipsByLocalpartStmt *sql.Stmt selectMembershipsByLocalpartStmt *sql.Stmt
selectRoomIDsByLocalPartStmt *sql.Stmt
} }
func (s *membershipStatements) prepare(db *sql.DB) (err error) { func (s *membershipStatements) prepare(db *sql.DB) (err error) {
@ -80,6 +84,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil { if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil {
return return
} }
if s.selectRoomIDsByLocalPartStmt, err = db.Prepare(selectRoomIDsByLocalPartSQL); err != nil {
return
}
return return
} }
@ -131,3 +138,23 @@ func (s *membershipStatements) selectMembershipsByLocalpart(
} }
return memberships, rows.Err() return memberships, rows.Err()
} }
func (s *membershipStatements) selectRoomIDsByLocalPart(
ctx context.Context, localPart string,
) ([]string, error) {
stmt := s.selectRoomIDsByLocalPartStmt
rows, err := stmt.QueryContext(ctx, localPart)
if err != nil {
return nil, err
}
roomIDs := []string{}
defer rows.Close() // nolint: errcheck
for rows.Next() {
var roomID string
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
roomIDs = append(roomIDs, roomID)
}
return roomIDs, rows.Err()
}

View file

@ -234,6 +234,16 @@ func (d *Database) GetMembershipInRoomByLocalpart(
return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID) return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID)
} }
// GetRoomIDsByLocalPart returns an array containing the room ids of all
// the rooms a user matching a given localpart is a member of
// If no membership match the given localpart, returns an empty array
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetRoomIDsByLocalPart(
ctx context.Context, localpart string,
) ([]string, error) {
return d.memberships.selectRoomIDsByLocalPart(ctx, localpart)
}
// GetMembershipsByLocalpart returns an array containing the memberships for all // GetMembershipsByLocalpart returns an array containing the memberships for all
// the rooms a user matching a given localpart is a member of // the rooms a user matching a given localpart is a member of
// If no membership match the given localpart, returns an empty array // If no membership match the given localpart, returns an empty array

View file

@ -51,6 +51,9 @@ const selectMembershipsByLocalpartSQL = "" +
const selectMembershipInRoomByLocalpartSQL = "" + const selectMembershipInRoomByLocalpartSQL = "" +
"SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2" "SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2"
const selectRoomIDsByLocalPartSQL = "" +
"SELECT room_id FROM account_memberships WHERE localpart = $1"
const deleteMembershipsByEventIDsSQL = "" + const deleteMembershipsByEventIDsSQL = "" +
"DELETE FROM account_memberships WHERE event_id IN ($1)" "DELETE FROM account_memberships WHERE event_id IN ($1)"
@ -58,6 +61,7 @@ type membershipStatements struct {
insertMembershipStmt *sql.Stmt insertMembershipStmt *sql.Stmt
selectMembershipInRoomByLocalpartStmt *sql.Stmt selectMembershipInRoomByLocalpartStmt *sql.Stmt
selectMembershipsByLocalpartStmt *sql.Stmt selectMembershipsByLocalpartStmt *sql.Stmt
selectRoomIDsByLocalPartStmt *sql.Stmt
} }
func (s *membershipStatements) prepare(db *sql.DB) (err error) { func (s *membershipStatements) prepare(db *sql.DB) (err error) {
@ -74,6 +78,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil { if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil {
return return
} }
if s.selectRoomIDsByLocalPartStmt, err = db.Prepare(selectRoomIDsByLocalPartSQL); err != nil {
return
}
return return
} }
@ -130,3 +137,22 @@ func (s *membershipStatements) selectMembershipsByLocalpart(
return return
} }
func (s *membershipStatements) selectRoomIDsByLocalPart(
ctx context.Context, localPart string,
) ([]string, error) {
stmt := s.selectRoomIDsByLocalPartStmt
rows, err := stmt.QueryContext(ctx, localPart)
if err != nil {
return nil, err
}
roomIDs := []string{}
defer rows.Close() // nolint: errcheck
for rows.Next() {
var roomID string
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
roomIDs = append(roomIDs, roomID)
}
return roomIDs, rows.Err()
}

View file

@ -253,6 +253,16 @@ func (d *Database) GetMembershipsByLocalpart(
return d.memberships.selectMembershipsByLocalpart(ctx, localpart) return d.memberships.selectMembershipsByLocalpart(ctx, localpart)
} }
// GetRoomIDsByLocalPart returns an array containing the room ids of all
// the rooms a user matching a given localpart is a member of
// If no membership match the given localpart, returns an empty array
// If there was an issue during the retrieval, returns the SQL error
func (d *Database) GetRoomIDsByLocalPart(
ctx context.Context, localpart string,
) ([]string, error) {
return d.memberships.selectRoomIDsByLocalPart(ctx, localpart)
}
// newMembership saves a new membership in the database. // newMembership saves a new membership in the database.
// If the event isn't a valid m.room.member event with type `join`, does nothing. // If the event isn't a valid m.room.member event with type `join`, does nothing.
// If an error occurred, returns the SQL error // If an error occurred, returns the SQL error

View file

@ -17,6 +17,8 @@ package routing
import ( import (
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
@ -25,10 +27,14 @@ import (
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
type response struct { type getMembershipResponse struct {
Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` Chunk []gomatrixserverlib.ClientEvent `json:"chunk"`
} }
type getJoinedRoomsResponse struct {
JoinedRooms []string `json:"joined_rooms"`
}
// GetMemberships implements GET /rooms/{roomId}/members // GetMemberships implements GET /rooms/{roomId}/members
func GetMemberships( func GetMemberships(
req *http.Request, device *authtypes.Device, roomID string, joinedOnly bool, req *http.Request, device *authtypes.Device, roomID string, joinedOnly bool,
@ -55,6 +61,27 @@ func GetMemberships(
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: response{queryRes.JoinEvents}, JSON: getMembershipResponse{queryRes.JoinEvents},
}
}
func GetJoinedRooms(
req *http.Request,
device *authtypes.Device,
accountsDB accounts.Database,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
}
joinedRooms, err := accountsDB.GetRoomIDsByLocalPart(req.Context(), localpart)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("accountsDB.GetRoomIDsByLocalPart failed")
return jsonerror.InternalServerError()
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: getJoinedRoomsResponse{joinedRooms},
} }
} }

View file

@ -105,6 +105,12 @@ func Setup(
) )
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/joined_rooms",
common.MakeAuthAPI("joined_rooms", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return GetJoinedRooms(req, device, accountDB)
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/{membership:(?:join|kick|ban|unban|leave|invite)}", r0mux.Handle("/rooms/{roomID}/{membership:(?:join|kick|ban|unban|leave|invite)}",
common.MakeAuthAPI("membership", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { common.MakeAuthAPI("membership", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req)) vars, err := common.URLDecodeMapValues(mux.Vars(req))

1
go.mod
View file

@ -27,6 +27,7 @@ require (
github.com/tidwall/pretty v1.0.1 // indirect github.com/tidwall/pretty v1.0.1 // indirect
github.com/uber/jaeger-client-go v2.22.1+incompatible github.com/uber/jaeger-client-go v2.22.1+incompatible
github.com/uber/jaeger-lib v2.2.0+incompatible github.com/uber/jaeger-lib v2.2.0+incompatible
go.uber.org/atomic v1.6.0 // indirect
golang.org/x/crypto v0.0.0-20200115085410-6d4e4cb37c7d golang.org/x/crypto v0.0.0-20200115085410-6d4e4cb37c7d
gopkg.in/Shopify/sarama.v1 v1.20.1 gopkg.in/Shopify/sarama.v1 v1.20.1
gopkg.in/h2non/bimg.v1 v1.0.18 gopkg.in/h2non/bimg.v1 v1.0.18

View file

@ -218,3 +218,5 @@ Push rules come down in an initial /sync
Regular users can add and delete aliases in the default room configuration Regular users can add and delete aliases in the default room configuration
Regular users can add and delete aliases when m.room.aliases is restricted Regular users can add and delete aliases when m.room.aliases is restricted
GET /r0/capabilities is not public GET /r0/capabilities is not public
GET /joined_rooms lists newly-created room
/joined_rooms returns only joined rooms