Move filter table to syncapi where it is used

This commit is contained in:
Kegan Dougal 2020-06-26 14:06:23 +01:00
parent 164057a3be
commit 79cf546d92
13 changed files with 86 additions and 94 deletions

View file

@ -376,26 +376,6 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/filter",
httputil.MakeAuthAPI("put_filter", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return PutFilter(req, device, accountDB, vars["userId"])
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/user/{userId}/filter/{filterId}",
httputil.MakeAuthAPI("get_filter", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return GetFilter(req, device, accountDB, vars["userId"], vars["filterId"])
}),
).Methods(http.MethodGet, http.MethodOptions)
// Riot user settings // Riot user settings
r0mux.Handle("/profile/{userID}", r0mux.Handle("/profile/{userID}",

View file

@ -19,15 +19,15 @@ import (
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
// GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId} // GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId}
func GetFilter( func GetFilter(
req *http.Request, device *api.Device, accountDB accounts.Database, userID string, filterID string, req *http.Request, device *api.Device, syncDB storage.Database, userID string, filterID string,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {
return util.JSONResponse{ return util.JSONResponse{
@ -41,7 +41,7 @@ func GetFilter(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
filter, err := accountDB.GetFilter(req.Context(), localpart, filterID) filter, err := syncDB.GetFilter(req.Context(), localpart, filterID)
if err != nil { if err != nil {
//TODO better error handling. This error message is *probably* right, //TODO better error handling. This error message is *probably* right,
// but if there are obscure db errors, this will also be returned, // but if there are obscure db errors, this will also be returned,
@ -64,7 +64,7 @@ type filterResponse struct {
//PutFilter implements POST /_matrix/client/r0/user/{userId}/filter //PutFilter implements POST /_matrix/client/r0/user/{userId}/filter
func PutFilter( func PutFilter(
req *http.Request, device *api.Device, accountDB accounts.Database, userID string, req *http.Request, device *api.Device, syncDB storage.Database, userID string,
) util.JSONResponse { ) util.JSONResponse {
if userID != device.UserID { if userID != device.UserID {
return util.JSONResponse{ return util.JSONResponse{
@ -93,9 +93,9 @@ func PutFilter(
} }
} }
filterID, err := accountDB.PutFilter(req.Context(), localpart, &filter) filterID, err := syncDB.PutFilter(req.Context(), localpart, &filter)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("accountDB.PutFilter failed") util.GetLogger(req.Context()).WithError(err).Error("syncDB.PutFilter failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }

View file

@ -55,4 +55,24 @@ func Setup(
} }
return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], federation, rsAPI, cfg) return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], federation, rsAPI, cfg)
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/filter",
httputil.MakeAuthAPI("put_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return PutFilter(req, device, syncDB, vars["userId"])
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/user/{userId}/filter/{filterId}",
httputil.MakeAuthAPI("get_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return GetFilter(req, device, syncDB, vars["userId"], vars["filterId"])
}),
).Methods(http.MethodGet, http.MethodOptions)
} }

View file

@ -128,4 +128,12 @@ type Database interface {
CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error) CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error)
// SendToDeviceUpdatesWaiting returns true if there are send-to-device updates waiting to be sent. // SendToDeviceUpdatesWaiting returns true if there are send-to-device updates waiting to be sent.
SendToDeviceUpdatesWaiting(ctx context.Context, userID, deviceID string) (bool, error) SendToDeviceUpdatesWaiting(ctx context.Context, userID, deviceID string) (bool, error)
// GetFilter looks up the filter associated with a given local user and filter ID.
// Returns a filter structure. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database.
GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error)
// PutFilter puts the passed filter into the database.
// Returns the filterID as a string. Otherwise returns an error if something
// goes wrong.
PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error)
} }

View file

@ -19,6 +19,7 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -53,24 +54,25 @@ type filterStatements struct {
insertFilterStmt *sql.Stmt insertFilterStmt *sql.Stmt
} }
func (s *filterStatements) prepare(db *sql.DB) (err error) { func NewPostgresFilterTable(db *sql.DB) (tables.Filter, error) {
_, err = db.Exec(filterSchema) _, err := db.Exec(filterSchema)
if err != nil { if err != nil {
return return nil, err
} }
s := &filterStatements{}
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
return return nil, err
} }
if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil { if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil {
return return nil, err
} }
if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil { if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil {
return return nil, err
} }
return return s, nil
} }
func (s *filterStatements) selectFilter( func (s *filterStatements) SelectFilter(
ctx context.Context, localpart string, filterID string, ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) { ) (*gomatrixserverlib.Filter, error) {
// Retrieve filter from database (stored as canonical JSON) // Retrieve filter from database (stored as canonical JSON)
@ -88,7 +90,7 @@ func (s *filterStatements) selectFilter(
return &filter, nil return &filter, nil
} }
func (s *filterStatements) insertFilter( func (s *filterStatements) InsertFilter(
ctx context.Context, filter *gomatrixserverlib.Filter, localpart string, ctx context.Context, filter *gomatrixserverlib.Filter, localpart string,
) (filterID string, err error) { ) (filterID string, err error) {
var existingFilterID string var existingFilterID string

View file

@ -71,6 +71,10 @@ func NewDatabase(dbDataSourceName string, dbProperties sqlutil.DbProperties) (*S
if err != nil { if err != nil {
return nil, err return nil, err
} }
filter, err := NewPostgresFilterTable(d.db)
if err != nil {
return nil, err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Invites: invites, Invites: invites,
@ -79,6 +83,7 @@ func NewDatabase(dbDataSourceName string, dbProperties sqlutil.DbProperties) (*S
Topology: topology, Topology: topology,
CurrentRoomState: currState, CurrentRoomState: currState,
BackwardExtremities: backwardExtremities, BackwardExtremities: backwardExtremities,
Filter: filter,
SendToDevice: sendToDevice, SendToDevice: sendToDevice,
SendToDeviceWriter: sqlutil.NewTransactionWriter(), SendToDeviceWriter: sqlutil.NewTransactionWriter(),
EDUCache: cache.New(), EDUCache: cache.New(),

View file

@ -43,6 +43,7 @@ type Database struct {
CurrentRoomState tables.CurrentRoomState CurrentRoomState tables.CurrentRoomState
BackwardExtremities tables.BackwardsExtremities BackwardExtremities tables.BackwardsExtremities
SendToDevice tables.SendToDevice SendToDevice tables.SendToDevice
Filter tables.Filter
SendToDeviceWriter *sqlutil.TransactionWriter SendToDeviceWriter *sqlutil.TransactionWriter
EDUCache *cache.EDUCache EDUCache *cache.EDUCache
} }
@ -545,6 +546,18 @@ func (d *Database) addEDUDeltaToResponse(
return return
} }
func (d *Database) GetFilter(
ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) {
return d.Filter.SelectFilter(ctx, localpart, filterID)
}
func (d *Database) PutFilter(
ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
) (string, error) {
return d.Filter.InsertFilter(ctx, filter, localpart)
}
func (d *Database) IncrementalSync( func (d *Database) IncrementalSync(
ctx context.Context, res *types.Response, ctx context.Context, res *types.Response,
device userapi.Device, device userapi.Device,

View file

@ -20,6 +20,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -54,24 +55,25 @@ type filterStatements struct {
insertFilterStmt *sql.Stmt insertFilterStmt *sql.Stmt
} }
func (s *filterStatements) prepare(db *sql.DB) (err error) { func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
_, err = db.Exec(filterSchema) _, err := db.Exec(filterSchema)
if err != nil { if err != nil {
return return nil, err
} }
s := &filterStatements{}
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
return return nil, err
} }
if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil { if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil {
return return nil, err
} }
if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil { if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil {
return return nil, err
} }
return return s, nil
} }
func (s *filterStatements) selectFilter( func (s *filterStatements) SelectFilter(
ctx context.Context, localpart string, filterID string, ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) { ) (*gomatrixserverlib.Filter, error) {
// Retrieve filter from database (stored as canonical JSON) // Retrieve filter from database (stored as canonical JSON)
@ -89,7 +91,7 @@ func (s *filterStatements) selectFilter(
return &filter, nil return &filter, nil
} }
func (s *filterStatements) insertFilter( func (s *filterStatements) InsertFilter(
ctx context.Context, filter *gomatrixserverlib.Filter, localpart string, ctx context.Context, filter *gomatrixserverlib.Filter, localpart string,
) (filterID string, err error) { ) (filterID string, err error) {
var existingFilterID string var existingFilterID string

View file

@ -87,6 +87,10 @@ func (d *SyncServerDatasource) prepare() (err error) {
if err != nil { if err != nil {
return err return err
} }
filter, err := NewSqliteFilterTable(d.db)
if err != nil {
return err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Invites: invites, Invites: invites,
@ -95,6 +99,7 @@ func (d *SyncServerDatasource) prepare() (err error) {
BackwardExtremities: bwExtrem, BackwardExtremities: bwExtrem,
CurrentRoomState: roomState, CurrentRoomState: roomState,
Topology: topology, Topology: topology,
Filter: filter,
SendToDevice: sendToDevice, SendToDevice: sendToDevice,
SendToDeviceWriter: sqlutil.NewTransactionWriter(), SendToDeviceWriter: sqlutil.NewTransactionWriter(),
EDUCache: cache.New(), EDUCache: cache.New(),

View file

@ -133,3 +133,8 @@ type SendToDevice interface {
DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error) DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error)
CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error) CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error)
} }
type Filter interface {
SelectFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error)
InsertFilter(ctx context.Context, filter *gomatrixserverlib.Filter, localpart string) (filterID string, err error)
}

View file

@ -52,8 +52,6 @@ type Database interface {
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error) GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error)
PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error)
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
} }

View file

@ -40,7 +40,6 @@ type Database struct {
memberships membershipStatements memberships membershipStatements
accountDatas accountDataStatements accountDatas accountDataStatements
threepids threepidStatements threepids threepidStatements
filter filterStatements
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
@ -75,11 +74,7 @@ func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, serve
if err = t.prepare(db); err != nil { if err = t.prepare(db); err != nil {
return nil, err return nil, err
} }
f := filterStatements{} return &Database{db, partitions, a, p, m, ac, t, serverName}, nil
if err = f.prepare(db); err != nil {
return nil, err
}
return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil
} }
// GetAccountByPassword returns the account associated with the given localpart and password. // GetAccountByPassword returns the account associated with the given localpart and password.
@ -396,24 +391,6 @@ func (d *Database) GetThreePIDsForLocalpart(
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
} }
// GetFilter looks up the filter associated with a given local user and filter ID.
// Returns a filter structure. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database.
func (d *Database) GetFilter(
ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) {
return d.filter.selectFilter(ctx, localpart, filterID)
}
// PutFilter puts the passed filter into the database.
// Returns the filterID as a string. Otherwise returns an error if something
// goes wrong.
func (d *Database) PutFilter(
ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
) (string, error) {
return d.filter.insertFilter(ctx, filter, localpart)
}
// CheckAccountAvailability checks if the username/localpart is already present // CheckAccountAvailability checks if the username/localpart is already present
// in the database. // in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken. // If the DB returns sql.ErrNoRows the Localpart isn't taken.

View file

@ -39,7 +39,6 @@ type Database struct {
memberships membershipStatements memberships membershipStatements
accountDatas accountDataStatements accountDatas accountDataStatements
threepids threepidStatements threepids threepidStatements
filter filterStatements
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
createAccountMu sync.Mutex createAccountMu sync.Mutex
@ -80,11 +79,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
if err = t.prepare(db); err != nil { if err = t.prepare(db); err != nil {
return nil, err return nil, err
} }
f := filterStatements{} return &Database{db, partitions, a, p, m, ac, t, serverName, sync.Mutex{}}, nil
if err = f.prepare(db); err != nil {
return nil, err
}
return &Database{db, partitions, a, p, m, ac, t, f, serverName, sync.Mutex{}}, nil
} }
// GetAccountByPassword returns the account associated with the given localpart and password. // GetAccountByPassword returns the account associated with the given localpart and password.
@ -410,24 +405,6 @@ func (d *Database) GetThreePIDsForLocalpart(
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
} }
// GetFilter looks up the filter associated with a given local user and filter ID.
// Returns a filter structure. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database.
func (d *Database) GetFilter(
ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) {
return d.filter.selectFilter(ctx, localpart, filterID)
}
// PutFilter puts the passed filter into the database.
// Returns the filterID as a string. Otherwise returns an error if something
// goes wrong.
func (d *Database) PutFilter(
ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
) (string, error) {
return d.filter.insertFilter(ctx, filter, localpart)
}
// CheckAccountAvailability checks if the username/localpart is already present // CheckAccountAvailability checks if the username/localpart is already present
// in the database. // in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken. // If the DB returns sql.ErrNoRows the Localpart isn't taken.