Add server_notice_room_id and methods to update/get it

This commit is contained in:
Till Faelligen 2022-03-07 09:41:25 +01:00
parent 519ea13510
commit 699617ee4d
12 changed files with 266 additions and 50 deletions

View file

@ -57,6 +57,8 @@ type UserInternalAPI interface {
QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error
QueryOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyRequest, res *QueryOutdatedPolicyResponse) error QueryOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyRequest, res *QueryOutdatedPolicyResponse) error
PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error
SelectServerNoticeRoomID(ctx context.Context, req *QueryServerNoticeRoomRequest, res *QueryServerNoticeRoomResponse) (err error)
UpdateServerNoticeRoomID(ctx context.Context, req *UpdateServerNoticeRoomRequest, res *UpdateServerNoticeRoomResponse) (err error)
} }
type PerformKeyBackupRequest struct { type PerformKeyBackupRequest struct {
@ -348,12 +350,12 @@ type QueryOpenIDTokenResponse struct {
ExpiresAtMS int64 ExpiresAtMS int64
} }
// QueryPolicyVersionRequest is the response for QueryPolicyVersionRequest // QueryPolicyVersionRequest is the request for QueryPolicyVersionRequest
type QueryPolicyVersionRequest struct { type QueryPolicyVersionRequest struct {
LocalPart string Localpart string
} }
// QueryPolicyVersionResponsestruct is the response for QueryPolicyVersionResponsestruct // QueryPolicyVersionResponse is the response for QueryPolicyVersionRequest
type QueryPolicyVersionResponse struct { type QueryPolicyVersionResponse struct {
PolicyVersion string PolicyVersion string
} }
@ -365,18 +367,36 @@ type QueryOutdatedPolicyRequest struct {
// QueryOutdatedPolicyResponse is the response for QueryOutdatedPolicyRequest // QueryOutdatedPolicyResponse is the response for QueryOutdatedPolicyRequest
type QueryOutdatedPolicyResponse struct { type QueryOutdatedPolicyResponse struct {
OutdatedUsers []string UserLocalparts []string
} }
// UpdatePolicyVersionRequest is the request for UpdatePolicyVersionRequest // UpdatePolicyVersionRequest is the request for UpdatePolicyVersionRequest
type UpdatePolicyVersionRequest struct { type UpdatePolicyVersionRequest struct {
PolicyVersion, LocalPart string PolicyVersion, Localpart string
ServerNoticeUpdate bool ServerNoticeUpdate bool
} }
// UpdatePolicyVersionResponse is the response for UpdatePolicyVersionRequest // UpdatePolicyVersionResponse is the response for UpdatePolicyVersionRequest
type UpdatePolicyVersionResponse struct{} type UpdatePolicyVersionResponse struct{}
// QueryServerNoticeRoomRequest is the request for QueryServerNoticeRoomRequest
type QueryServerNoticeRoomRequest struct {
Localpart string
}
// QueryServerNoticeRoomResponse is the response for QueryServerNoticeRoomRequest
type QueryServerNoticeRoomResponse struct {
RoomID string
}
// UpdateServerNoticeRoomRequest is the request for UpdateServerNoticeRoomRequest
type UpdateServerNoticeRoomRequest struct {
Localpart, RoomID string
}
// UpdateServerNoticeRoomResponse is the response for UpdateServerNoticeRoomRequest
type UpdateServerNoticeRoomResponse struct{}
// Device represents a client's device (mobile, web, etc) // Device represents a client's device (mobile, web, etc)
type Device struct { type Device struct {
ID string ID string

View file

@ -167,6 +167,18 @@ func (t *UserInternalAPITrace) PerformUpdatePolicyVersion(ctx context.Context, r
return err return err
} }
func (t *UserInternalAPITrace) SelectServerNoticeRoomID(ctx context.Context, req *QueryServerNoticeRoomRequest, res *QueryServerNoticeRoomResponse) error {
err := t.Impl.SelectServerNoticeRoomID(ctx, req, res)
util.GetLogger(ctx).Infof("SelectServerNoticeRoomID req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) UpdateServerNoticeRoomID(ctx context.Context, req *UpdateServerNoticeRoomRequest, res *UpdateServerNoticeRoomResponse) error {
err := t.Impl.UpdateServerNoticeRoomID(ctx, req, res)
util.GetLogger(ctx).Infof("UpdateServerNoticeRoomID req=%+v res=%+v", js(req), js(res))
return err
}
func js(thing interface{}) string { func js(thing interface{}) string {
b, err := json.Marshal(thing) b, err := json.Marshal(thing)
if err != nil { if err != nil {

View file

@ -769,7 +769,7 @@ func (a *UserInternalAPI) QueryPolicyVersion(
res *api.QueryPolicyVersionResponse, res *api.QueryPolicyVersionResponse,
) error { ) error {
var err error var err error
res.PolicyVersion, err = a.DB.GetPrivacyPolicy(ctx, req.LocalPart) res.PolicyVersion, err = a.DB.GetPrivacyPolicy(ctx, req.Localpart)
if err != nil { if err != nil {
return err return err
} }
@ -783,7 +783,7 @@ func (a *UserInternalAPI) QueryOutdatedPolicy(
res *api.QueryOutdatedPolicyResponse, res *api.QueryOutdatedPolicyResponse,
) error { ) error {
var err error var err error
res.OutdatedUsers, err = a.DB.GetOutdatedPolicy(ctx, req.PolicyVersion) res.UserLocalparts, err = a.DB.GetOutdatedPolicy(ctx, req.PolicyVersion)
if err != nil { if err != nil {
return err return err
} }
@ -796,5 +796,26 @@ func (a *UserInternalAPI) PerformUpdatePolicyVersion(
req *api.UpdatePolicyVersionRequest, req *api.UpdatePolicyVersionRequest,
res *api.UpdatePolicyVersionResponse, res *api.UpdatePolicyVersionResponse,
) error { ) error {
return a.DB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.LocalPart, req.ServerNoticeUpdate) return a.DB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.Localpart, req.ServerNoticeUpdate)
}
func (a *UserInternalAPI) SelectServerNoticeRoomID(
ctx context.Context,
req *api.QueryServerNoticeRoomRequest,
res *api.QueryServerNoticeRoomResponse,
) (err error) {
roomID, err := a.DB.SelectServerNoticeRoomID(ctx, req.Localpart)
if err != nil {
return err
}
res.RoomID = roomID
return nil
}
func (a *UserInternalAPI) UpdateServerNoticeRoomID(
ctx context.Context,
req *api.UpdateServerNoticeRoomRequest,
res *api.UpdateServerNoticeRoomResponse,
) (err error) {
return a.DB.UpdateServerNoticeRoomID(ctx, req.Localpart, req.RoomID)
} }

View file

@ -41,6 +41,7 @@ const (
PerformPusherDeletionPath = "/pushserver/performPusherDeletion" PerformPusherDeletionPath = "/pushserver/performPusherDeletion"
PerformPushRulesPutPath = "/pushserver/performPushRulesPut" PerformPushRulesPutPath = "/pushserver/performPushRulesPut"
PerformUpdatePolicyVersionPath = "/userapi/performUpdatePolicyVersion" PerformUpdatePolicyVersionPath = "/userapi/performUpdatePolicyVersion"
PerformUpdateServerNoticeRoomPath = "/userapi/performUpdateServerNoticeRoom"
QueryKeyBackupPath = "/userapi/queryKeyBackup" QueryKeyBackupPath = "/userapi/queryKeyBackup"
QueryProfilePath = "/userapi/queryProfile" QueryProfilePath = "/userapi/queryProfile"
@ -55,6 +56,7 @@ const (
QueryNotificationsPath = "/pushserver/queryNotifications" QueryNotificationsPath = "/pushserver/queryNotifications"
QueryPolicyVersionPath = "/userapi/queryPolicyVersion" QueryPolicyVersionPath = "/userapi/queryPolicyVersion"
QueryOutdatedPolicyUsersPath = "/userapi/queryOutdatedPolicy" QueryOutdatedPolicyUsersPath = "/userapi/queryOutdatedPolicy"
QueryServerNoticeRoomPath = "/userapi/queryServerNoticeRoom"
) )
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@ -337,3 +339,19 @@ func (h *httpUserInternalAPI) PerformUpdatePolicyVersion(ctx context.Context, re
apiURL := h.apiURL + PerformUpdatePolicyVersionPath apiURL := h.apiURL + PerformUpdatePolicyVersionPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
func (h *httpUserInternalAPI) SelectServerNoticeRoomID(ctx context.Context, req *api.QueryServerNoticeRoomRequest, res *api.QueryServerNoticeRoomResponse) (err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "SelectServerNoticeRoomID")
defer span.Finish()
apiURL := h.apiURL + QueryServerNoticeRoomPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) UpdateServerNoticeRoomID(ctx context.Context, req *api.UpdateServerNoticeRoomRequest, res *api.UpdateServerNoticeRoomResponse) (err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "UpdateServerNoticeRoomID")
defer span.Finish()
apiURL := h.apiURL + PerformUpdateServerNoticeRoomPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}

View file

@ -389,4 +389,32 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(QueryServerNoticeRoomPath,
httputil.MakeInternalAPI("queryServerNoticeRoom", func(req *http.Request) util.JSONResponse {
request := api.QueryServerNoticeRoomRequest{}
response := api.QueryServerNoticeRoomResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
err := s.SelectServerNoticeRoomID(req.Context(), &request, &response)
if err != nil {
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformUpdateServerNoticeRoomPath,
httputil.MakeInternalAPI("performUpdateServerNoticeRoom", func(req *http.Request) util.JSONResponse {
request := api.UpdateServerNoticeRoomRequest{}
response := api.UpdateServerNoticeRoomResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
err := s.UpdateServerNoticeRoomID(req.Context(), &request, &response)
if err != nil {
return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
} }

View file

@ -55,6 +55,8 @@ type Database interface {
GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error)
GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error)
UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string, serverNotice bool) error UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string, serverNotice bool) error
SelectServerNoticeRoomID(ctx context.Context, localpart string) (roomID string, err error)
UpdateServerNoticeRoomID(ctx context.Context, localpart, roomID string) (err error)
// Key backups // Key backups
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error) CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)

View file

@ -47,7 +47,8 @@ CREATE TABLE IF NOT EXISTS account_accounts (
-- The policy version this user has accepted -- The policy version this user has accepted
policy_version TEXT, policy_version TEXT,
-- The policy version the user received from the server notices room -- The policy version the user received from the server notices room
policy_version_sent TEXT policy_version_sent TEXT,
server_notice_room_id TEXT
-- TODO: -- TODO:
-- upgraded_ts, devices, any email reset stuff? -- upgraded_ts, devices, any email reset stuff?
); );
@ -85,6 +86,12 @@ const updatePolicyVersionSQL = "" +
const updatePolicyVersionServerNoticeSQL = "" + const updatePolicyVersionServerNoticeSQL = "" +
"UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2" "UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2"
const selectServerNoticeRoomSQL = "" +
"SELECT server_notice_room_id FROM account_accounts WHERE localpart = $1"
const updateServerNoticeRoomSQL = "" +
"UPDATE account_accounts SET server_notice_room_id = $1 WHERE localpart = $2"
type accountsStatements struct { type accountsStatements struct {
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt
updatePasswordStmt *sql.Stmt updatePasswordStmt *sql.Stmt
@ -96,6 +103,8 @@ type accountsStatements struct {
batchSelectPrivacyPolicyStmt *sql.Stmt batchSelectPrivacyPolicyStmt *sql.Stmt
updatePolicyVersionStmt *sql.Stmt updatePolicyVersionStmt *sql.Stmt
updatePolicyVersionServerNoticeStmt *sql.Stmt updatePolicyVersionServerNoticeStmt *sql.Stmt
selectServerNoticeRoomStmt *sql.Stmt
updateServerNoticeRoomStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
@ -118,6 +127,8 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL}, {&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL}, {&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
{&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL}, {&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL},
{&s.selectServerNoticeRoomStmt, selectServerNoticeRoomSQL},
{&s.updateServerNoticeRoomStmt, updateServerNoticeRoomSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -130,12 +141,7 @@ func (s *accountsStatements) InsertAccount(
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
var err error _, err := stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion)
if accountType != api.AccountTypeAppService {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion)
} else {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, "")
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -251,3 +257,36 @@ func (s *accountsStatements) UpdatePolicyVersion(
_, err = stmt.ExecContext(ctx, policyVersion, localpart) _, err = stmt.ExecContext(ctx, policyVersion, localpart)
return err return err
} }
// SelectServerNoticeRoomID queries the server notice room ID.
func (s *accountsStatements) SelectServerNoticeRoomID(
ctx context.Context, txn *sql.Tx, localpart string,
) (roomID string, err error) {
stmt := s.selectServerNoticeRoomStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
roomIDNull := sql.NullString{}
row := stmt.QueryRowContext(ctx, localpart)
err = row.Scan(&roomIDNull)
if err != nil {
return "", err
}
if roomIDNull.Valid {
return roomIDNull.String, nil
}
return "", nil
}
// UpdateServerNoticeRoomID sets the server notice room ID.
func (s *accountsStatements) UpdateServerNoticeRoomID(
ctx context.Context, txn *sql.Tx, localpart, roomID string,
) (err error) {
stmt := s.updateServerNoticeRoomStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
_, err = stmt.ExecContext(ctx, roomID, localpart)
return
}

View file

@ -20,13 +20,24 @@ func UpAddPolicyVersion(tx *sql.Tx) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err) return fmt.Errorf("failed to execute upgrade: %w", err)
} }
_, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS server_notice_room_id TEXT;")
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil return nil
} }
func DownAddPolicyVersion(tx *sql.Tx) error { func DownAddPolicyVersion(tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;" + _, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;")
"ALTER TABLE account_accounts DROP COLUMN policy_version_sent;") if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version_sent;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN server_notice_room_id;")
if err != nil { if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err) return fmt.Errorf("failed to execute downgrade: %w", err)
} }

View file

@ -791,3 +791,16 @@ func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, local
}) })
return return
} }
// SelectServerNoticeRoomID returns the server notice room, if one is set.
func (d *Database) SelectServerNoticeRoomID(ctx context.Context, localpart string) (roomID string, err error) {
return d.Accounts.SelectServerNoticeRoomID(ctx, nil, localpart)
}
// UpdateServerNoticeRoomID updates the server notice room
func (d *Database) UpdateServerNoticeRoomID(ctx context.Context, localpart, roomID string) (err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Accounts.UpdateServerNoticeRoomID(ctx, txn, localpart, roomID)
})
return
}

View file

@ -47,7 +47,8 @@ CREATE TABLE IF NOT EXISTS account_accounts (
-- The policy version this user has accepted -- The policy version this user has accepted
policy_version TEXT, policy_version TEXT,
-- The policy version the user received from the server notices room -- The policy version the user received from the server notices room
policy_version_sent TEXT policy_version_sent TEXT,
server_notice_room_id TEXT
-- TODO: -- TODO:
-- upgraded_ts, devices, any email reset stuff? -- upgraded_ts, devices, any email reset stuff?
); );
@ -83,6 +84,12 @@ const updatePolicyVersionSQL = "" +
const updatePolicyVersionServerNoticeSQL = "" + const updatePolicyVersionServerNoticeSQL = "" +
"UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2" "UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2"
const selectServerNoticeRoomSQL = "" +
"SELECT server_notice_room_id FROM account_accounts WHERE localpart = $1"
const updateServerNoticeRoomSQL = "" +
"UPDATE account_accounts SET server_notice_room_id = $1 WHERE localpart = $2"
type accountsStatements struct { type accountsStatements struct {
db *sql.DB db *sql.DB
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt
@ -95,6 +102,8 @@ type accountsStatements struct {
batchSelectPrivacyPolicyStmt *sql.Stmt batchSelectPrivacyPolicyStmt *sql.Stmt
updatePolicyVersionStmt *sql.Stmt updatePolicyVersionStmt *sql.Stmt
updatePolicyVersionServerNoticeStmt *sql.Stmt updatePolicyVersionServerNoticeStmt *sql.Stmt
selectServerNoticeRoomStmt *sql.Stmt
updateServerNoticeRoomStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
@ -118,6 +127,8 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
{&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL}, {&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL},
{&s.updatePolicyVersionStmt, updatePolicyVersionSQL}, {&s.updatePolicyVersionStmt, updatePolicyVersionSQL},
{&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL}, {&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL},
{&s.selectServerNoticeRoomStmt, selectServerNoticeRoomSQL},
{&s.updateServerNoticeRoomStmt, updateServerNoticeRoomSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -130,12 +141,7 @@ func (s *accountsStatements) InsertAccount(
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt stmt := s.insertAccountStmt
var err error _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, policyVersion)
if accountType != api.AccountTypeAppService {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion)
} else {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, "")
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -254,3 +260,36 @@ func (s *accountsStatements) UpdatePolicyVersion(
_, err = stmt.ExecContext(ctx, policyVersion, localpart) _, err = stmt.ExecContext(ctx, policyVersion, localpart)
return err return err
} }
// SelectServerNoticeRoomID queries the server notice room ID.
func (s *accountsStatements) SelectServerNoticeRoomID(
ctx context.Context, txn *sql.Tx, localpart string,
) (roomID string, err error) {
stmt := s.selectServerNoticeRoomStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
roomIDNull := sql.NullString{}
row := stmt.QueryRowContext(ctx, localpart)
err = row.Scan(&roomIDNull)
if err != nil {
return "", err
}
if roomIDNull.Valid {
return roomIDNull.String, nil
}
return "", nil
}
// UpdateServerNoticeRoomID sets the server notice room ID.
func (s *accountsStatements) UpdateServerNoticeRoomID(
ctx context.Context, txn *sql.Tx, localpart, roomID string,
) (err error) {
stmt := s.updateServerNoticeRoomStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
_, err = stmt.ExecContext(ctx, roomID, localpart)
return
}

View file

@ -20,12 +20,23 @@ func UpAddPolicyVersion(tx *sql.Tx) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err) return fmt.Errorf("failed to execute upgrade: %w", err)
} }
_, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN server_notice_room_id TEXT;")
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil return nil
} }
func DownAddPolicyVersion(tx *sql.Tx) error { func DownAddPolicyVersion(tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;" + _, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;")
"ALTER TABLE account_accounts DROP COLUMN policy_version_sent;") if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version_sent;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
_, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN server_notice_room_id;")
if err != nil { if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err) return fmt.Errorf("failed to execute downgrade: %w", err)
} }

View file

@ -41,6 +41,8 @@ type AccountsTable interface {
SelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, localPart string) (policy string, err error) SelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, localPart string) (policy string, err error)
BatchSelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, policyVersion string) (userIDs []string, err error) BatchSelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, policyVersion string) (userIDs []string, err error)
UpdatePolicyVersion(ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool) (err error) UpdatePolicyVersion(ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool) (err error)
SelectServerNoticeRoomID(ctx context.Context, txn *sql.Tx, localpart string) (roomID string, err error)
UpdateServerNoticeRoomID(ctx context.Context, txn *sql.Tx, localpart, roomID string) (err error)
} }
type DevicesTable interface { type DevicesTable interface {