diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index eadcfd1ab..9dfff0f20 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -376,26 +376,6 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/user/{userId}/filter", - httputil.MakeAuthAPI("put_filter", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { - vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) - if err != nil { - return util.ErrorResponse(err) - } - return PutFilter(req, device, accountDB, vars["userId"]) - }), - ).Methods(http.MethodPost, http.MethodOptions) - - r0mux.Handle("/user/{userId}/filter/{filterId}", - httputil.MakeAuthAPI("get_filter", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { - vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) - if err != nil { - return util.ErrorResponse(err) - } - return GetFilter(req, device, accountDB, vars["userId"], vars["filterId"]) - }), - ).Methods(http.MethodGet, http.MethodOptions) - // Riot user settings r0mux.Handle("/profile/{userID}", diff --git a/clientapi/routing/filter.go b/syncapi/routing/filter.go similarity index 85% rename from clientapi/routing/filter.go rename to syncapi/routing/filter.go index 6520e6e40..1505e2a4b 100644 --- a/clientapi/routing/filter.go +++ b/syncapi/routing/filter.go @@ -19,15 +19,15 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) // GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId} func GetFilter( - req *http.Request, device *api.Device, accountDB accounts.Database, userID string, filterID string, + req *http.Request, device *api.Device, syncDB storage.Database, userID string, filterID string, ) util.JSONResponse { if userID != device.UserID { return util.JSONResponse{ @@ -41,7 +41,7 @@ func GetFilter( return jsonerror.InternalServerError() } - filter, err := accountDB.GetFilter(req.Context(), localpart, filterID) + filter, err := syncDB.GetFilter(req.Context(), localpart, filterID) if err != nil { //TODO better error handling. This error message is *probably* right, // but if there are obscure db errors, this will also be returned, @@ -64,7 +64,7 @@ type filterResponse struct { //PutFilter implements POST /_matrix/client/r0/user/{userId}/filter func PutFilter( - req *http.Request, device *api.Device, accountDB accounts.Database, userID string, + req *http.Request, device *api.Device, syncDB storage.Database, userID string, ) util.JSONResponse { if userID != device.UserID { return util.JSONResponse{ @@ -93,9 +93,9 @@ func PutFilter( } } - filterID, err := accountDB.PutFilter(req.Context(), localpart, &filter) + filterID, err := syncDB.PutFilter(req.Context(), localpart, &filter) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("accountDB.PutFilter failed") + util.GetLogger(req.Context()).WithError(err).Error("syncDB.PutFilter failed") return jsonerror.InternalServerError() } diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 5744de05a..a98955c57 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -55,4 +55,24 @@ func Setup( } return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], federation, rsAPI, cfg) })).Methods(http.MethodGet, http.MethodOptions) + + r0mux.Handle("/user/{userId}/filter", + httputil.MakeAuthAPI("put_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return PutFilter(req, device, syncDB, vars["userId"]) + }), + ).Methods(http.MethodPost, http.MethodOptions) + + r0mux.Handle("/user/{userId}/filter/{filterId}", + httputil.MakeAuthAPI("get_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetFilter(req, device, syncDB, vars["userId"], vars["filterId"]) + }), + ).Methods(http.MethodGet, http.MethodOptions) } diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index c693326b4..c4dae4d09 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -128,4 +128,12 @@ type Database interface { CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error) // SendToDeviceUpdatesWaiting returns true if there are send-to-device updates waiting to be sent. SendToDeviceUpdatesWaiting(ctx context.Context, userID, deviceID string) (bool, error) + // GetFilter looks up the filter associated with a given local user and filter ID. + // Returns a filter structure. Otherwise returns an error if no such filter exists + // or if there was an error talking to the database. + GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error) + // PutFilter puts the passed filter into the database. + // Returns the filterID as a string. Otherwise returns an error if something + // goes wrong. + PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error) } diff --git a/userapi/storage/accounts/postgres/filter_table.go b/syncapi/storage/postgres/filter_table.go similarity index 90% rename from userapi/storage/accounts/postgres/filter_table.go rename to syncapi/storage/postgres/filter_table.go index c54e4bc42..3747caf7b 100644 --- a/userapi/storage/accounts/postgres/filter_table.go +++ b/syncapi/storage/postgres/filter_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "encoding/json" + "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) @@ -53,24 +54,25 @@ type filterStatements struct { insertFilterStmt *sql.Stmt } -func (s *filterStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(filterSchema) +func NewPostgresFilterTable(db *sql.DB) (tables.Filter, error) { + _, err := db.Exec(filterSchema) if err != nil { - return + return nil, err } + s := &filterStatements{} if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { - return + return nil, err } if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil { - return + return nil, err } if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil { - return + return nil, err } - return + return s, nil } -func (s *filterStatements) selectFilter( +func (s *filterStatements) SelectFilter( ctx context.Context, localpart string, filterID string, ) (*gomatrixserverlib.Filter, error) { // Retrieve filter from database (stored as canonical JSON) @@ -88,7 +90,7 @@ func (s *filterStatements) selectFilter( return &filter, nil } -func (s *filterStatements) insertFilter( +func (s *filterStatements) InsertFilter( ctx context.Context, filter *gomatrixserverlib.Filter, localpart string, ) (filterID string, err error) { var existingFilterID string diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 573586cc7..10c1b37c7 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -71,6 +71,10 @@ func NewDatabase(dbDataSourceName string, dbProperties sqlutil.DbProperties) (*S if err != nil { return nil, err } + filter, err := NewPostgresFilterTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ DB: d.db, Invites: invites, @@ -79,6 +83,7 @@ func NewDatabase(dbDataSourceName string, dbProperties sqlutil.DbProperties) (*S Topology: topology, CurrentRoomState: currState, BackwardExtremities: backwardExtremities, + Filter: filter, SendToDevice: sendToDevice, SendToDeviceWriter: sqlutil.NewTransactionWriter(), EDUCache: cache.New(), diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index f84dc341e..7a84dd151 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -43,6 +43,7 @@ type Database struct { CurrentRoomState tables.CurrentRoomState BackwardExtremities tables.BackwardsExtremities SendToDevice tables.SendToDevice + Filter tables.Filter SendToDeviceWriter *sqlutil.TransactionWriter EDUCache *cache.EDUCache } @@ -545,6 +546,18 @@ func (d *Database) addEDUDeltaToResponse( return } +func (d *Database) GetFilter( + ctx context.Context, localpart string, filterID string, +) (*gomatrixserverlib.Filter, error) { + return d.Filter.SelectFilter(ctx, localpart, filterID) +} + +func (d *Database) PutFilter( + ctx context.Context, localpart string, filter *gomatrixserverlib.Filter, +) (string, error) { + return d.Filter.InsertFilter(ctx, filter, localpart) +} + func (d *Database) IncrementalSync( ctx context.Context, res *types.Response, device userapi.Device, diff --git a/userapi/storage/accounts/sqlite3/filter_table.go b/syncapi/storage/sqlite3/filter_table.go similarity index 91% rename from userapi/storage/accounts/sqlite3/filter_table.go rename to syncapi/storage/sqlite3/filter_table.go index 7f1a0c249..4cd74e8f2 100644 --- a/userapi/storage/accounts/sqlite3/filter_table.go +++ b/syncapi/storage/sqlite3/filter_table.go @@ -20,6 +20,7 @@ import ( "encoding/json" "fmt" + "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) @@ -54,24 +55,25 @@ type filterStatements struct { insertFilterStmt *sql.Stmt } -func (s *filterStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(filterSchema) +func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) { + _, err := db.Exec(filterSchema) if err != nil { - return + return nil, err } + s := &filterStatements{} if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { - return + return nil, err } if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil { - return + return nil, err } if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil { - return + return nil, err } - return + return s, nil } -func (s *filterStatements) selectFilter( +func (s *filterStatements) SelectFilter( ctx context.Context, localpart string, filterID string, ) (*gomatrixserverlib.Filter, error) { // Retrieve filter from database (stored as canonical JSON) @@ -89,7 +91,7 @@ func (s *filterStatements) selectFilter( return &filter, nil } -func (s *filterStatements) insertFilter( +func (s *filterStatements) InsertFilter( ctx context.Context, filter *gomatrixserverlib.Filter, localpart string, ) (filterID string, err error) { var existingFilterID string diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 51cdbe325..c85db5a4f 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -87,6 +87,10 @@ func (d *SyncServerDatasource) prepare() (err error) { if err != nil { return err } + filter, err := NewSqliteFilterTable(d.db) + if err != nil { + return err + } d.Database = shared.Database{ DB: d.db, Invites: invites, @@ -95,6 +99,7 @@ func (d *SyncServerDatasource) prepare() (err error) { BackwardExtremities: bwExtrem, CurrentRoomState: roomState, Topology: topology, + Filter: filter, SendToDevice: sendToDevice, SendToDeviceWriter: sqlutil.NewTransactionWriter(), EDUCache: cache.New(), diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 246dc6955..28a1786b3 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -133,3 +133,8 @@ type SendToDevice interface { DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error) CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error) } + +type Filter interface { + SelectFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error) + InsertFilter(ctx context.Context, filter *gomatrixserverlib.Filter, localpart string) (filterID string, err error) +} diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go index c6692879b..9ed33e1b9 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -52,8 +52,6 @@ type Database interface { RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error) GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) - GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error) - PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) } diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index e55099800..f0b11bfdb 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -40,7 +40,6 @@ type Database struct { memberships membershipStatements accountDatas accountDataStatements threepids threepidStatements - filter filterStatements serverName gomatrixserverlib.ServerName } @@ -75,11 +74,7 @@ func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, serve if err = t.prepare(db); err != nil { return nil, err } - f := filterStatements{} - if err = f.prepare(db); err != nil { - return nil, err - } - return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil + return &Database{db, partitions, a, p, m, ac, t, serverName}, nil } // GetAccountByPassword returns the account associated with the given localpart and password. @@ -396,24 +391,6 @@ func (d *Database) GetThreePIDsForLocalpart( return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) } -// GetFilter looks up the filter associated with a given local user and filter ID. -// Returns a filter structure. Otherwise returns an error if no such filter exists -// or if there was an error talking to the database. -func (d *Database) GetFilter( - ctx context.Context, localpart string, filterID string, -) (*gomatrixserverlib.Filter, error) { - return d.filter.selectFilter(ctx, localpart, filterID) -} - -// PutFilter puts the passed filter into the database. -// Returns the filterID as a string. Otherwise returns an error if something -// goes wrong. -func (d *Database) PutFilter( - ctx context.Context, localpart string, filter *gomatrixserverlib.Filter, -) (string, error) { - return d.filter.insertFilter(ctx, filter, localpart) -} - // CheckAccountAvailability checks if the username/localpart is already present // in the database. // If the DB returns sql.ErrNoRows the Localpart isn't taken. diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index d84f25b1f..e965df4f9 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -39,7 +39,6 @@ type Database struct { memberships membershipStatements accountDatas accountDataStatements threepids threepidStatements - filter filterStatements serverName gomatrixserverlib.ServerName createAccountMu sync.Mutex @@ -80,11 +79,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) if err = t.prepare(db); err != nil { return nil, err } - f := filterStatements{} - if err = f.prepare(db); err != nil { - return nil, err - } - return &Database{db, partitions, a, p, m, ac, t, f, serverName, sync.Mutex{}}, nil + return &Database{db, partitions, a, p, m, ac, t, serverName, sync.Mutex{}}, nil } // GetAccountByPassword returns the account associated with the given localpart and password. @@ -410,24 +405,6 @@ func (d *Database) GetThreePIDsForLocalpart( return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) } -// GetFilter looks up the filter associated with a given local user and filter ID. -// Returns a filter structure. Otherwise returns an error if no such filter exists -// or if there was an error talking to the database. -func (d *Database) GetFilter( - ctx context.Context, localpart string, filterID string, -) (*gomatrixserverlib.Filter, error) { - return d.filter.selectFilter(ctx, localpart, filterID) -} - -// PutFilter puts the passed filter into the database. -// Returns the filterID as a string. Otherwise returns an error if something -// goes wrong. -func (d *Database) PutFilter( - ctx context.Context, localpart string, filter *gomatrixserverlib.Filter, -) (string, error) { - return d.filter.insertFilter(ctx, filter, localpart) -} - // CheckAccountAvailability checks if the username/localpart is already present // in the database. // If the DB returns sql.ErrNoRows the Localpart isn't taken.