Variable and type clarifications

Signed-off-by: Alex Chen <minecnly@gmail.com>
This commit is contained in:
Cnly 2019-08-01 11:03:24 +08:00
parent ac775bb79d
commit 1772d830da
3 changed files with 15 additions and 11 deletions

View file

@ -56,6 +56,9 @@ type redactionStatements struct {
bulkUpdateValidationStatusStmt *sql.Stmt bulkUpdateValidationStatusStmt *sql.Stmt
} }
// redactedToRedactionMap is a map in the form map[redactedEventID]redactionEventNID.
type redactedToRedactionMap map[string]types.EventNID
func (s *redactionStatements) prepare(db *sql.DB) (err error) { func (s *redactionStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(redactionsSchema) _, err = db.Exec(redactionsSchema)
if err != nil { if err != nil {
@ -81,13 +84,14 @@ func (s *redactionStatements) insertRedaction(
return err return err
} }
// bulkSelectRedaction returns the redactions for the given event IDs.
// Return values validated and unvalidated are both map[redactedEventID]redactedByNID.
func (s *redactionStatements) bulkSelectRedaction( func (s *redactionStatements) bulkSelectRedaction(
ctx context.Context, ctx context.Context,
txn *sql.Tx, txn *sql.Tx,
eventIDs []string, eventIDs []string,
) ( ) (
validated map[string]types.EventNID, validated, unvalidated redactedToRedactionMap,
unvalidated map[string]types.EventNID,
err error, err error,
) { ) {
stmt := common.TxStmt(txn, s.bulkSelectRedactionStmt) stmt := common.TxStmt(txn, s.bulkSelectRedactionStmt)

View file

@ -295,7 +295,7 @@ func (d *Database) applyRedactions(
} }
if len(unvalidatedRedactions) != 0 { if len(unvalidatedRedactions) != 0 {
var newlyValidated map[string]types.EventNID var newlyValidated redactedToRedactionMap
if newlyValidated, err = d.validateRedactions( if newlyValidated, err = d.validateRedactions(
ctx, unvalidatedRedactions, redactionNIDToEvent, eventIDToEventPointer, ctx, unvalidatedRedactions, redactionNIDToEvent, eventIDToEventPointer,
); err != nil { ); err != nil {
@ -325,7 +325,7 @@ func (d *Database) applyRedactions(
func (d *Database) fetchRedactionEvents( func (d *Database) fetchRedactionEvents(
ctx context.Context, ctx context.Context,
validatedRedactions, unvalidatedRedactions map[string]types.EventNID, validatedRedactions, unvalidatedRedactions redactedToRedactionMap,
) (redactionNIDToEvent map[types.EventNID]*gomatrixserverlib.Event, err error) { ) (redactionNIDToEvent map[types.EventNID]*gomatrixserverlib.Event, err error) {
redactionEventsToFetch := make([]types.EventNID, 0, len(validatedRedactions)+len(unvalidatedRedactions)) redactionEventsToFetch := make([]types.EventNID, 0, len(validatedRedactions)+len(unvalidatedRedactions))
for _, nid := range validatedRedactions { for _, nid := range validatedRedactions {
@ -354,15 +354,15 @@ func (d *Database) fetchRedactionEvents(
func (d *Database) validateRedactions( func (d *Database) validateRedactions(
ctx context.Context, ctx context.Context,
unvalidatedRedactions map[string]types.EventNID, unvalidatedRedactions redactedToRedactionMap,
redactionNIDToEvent map[types.EventNID]*gomatrixserverlib.Event, redactionNIDToEvent map[types.EventNID]*gomatrixserverlib.Event,
eventIDToEvent map[string]*gomatrixserverlib.Event, redactedIDToEvent map[string]*gomatrixserverlib.Event,
) (validatedRedactions map[string]types.EventNID, err error) { ) (validatedRedactions redactedToRedactionMap, err error) {
validatedRedactions = make(map[string]types.EventNID, len(unvalidatedRedactions)) validatedRedactions = make(redactedToRedactionMap, len(unvalidatedRedactions))
for redactedEventID, redactedByNID := range unvalidatedRedactions { for redactedEventID, redactedByNID := range unvalidatedRedactions {
badEvents, needPowerLevelCheck, validationErr := common.ValidateRedaction( badEvents, needPowerLevelCheck, validationErr := common.ValidateRedaction(
eventIDToEvent[redactedEventID], redactionNIDToEvent[redactedByNID], redactedIDToEvent[redactedEventID], redactionNIDToEvent[redactedByNID],
) )
if validationErr != nil { if validationErr != nil {
return nil, validationErr return nil, validationErr

View file

@ -982,13 +982,13 @@ func (d *SyncServerDatasource) validateRedactions(
txn *sql.Tx, txn *sql.Tx,
unvalidatedRedactions redactedToRedactionMap, unvalidatedRedactions redactedToRedactionMap,
redactionIDToEvent map[string]*gomatrixserverlib.Event, redactionIDToEvent map[string]*gomatrixserverlib.Event,
eventIDToEvent map[string]*gomatrixserverlib.Event, redactedIDToEvent map[string]*gomatrixserverlib.Event,
) (validatedRedactions redactedToRedactionMap, err error) { ) (validatedRedactions redactedToRedactionMap, err error) {
validatedRedactions = make(redactedToRedactionMap, len(unvalidatedRedactions)) validatedRedactions = make(redactedToRedactionMap, len(unvalidatedRedactions))
for redactedEventID, redactedByID := range unvalidatedRedactions { for redactedEventID, redactedByID := range unvalidatedRedactions {
badEvents, needPowerLevelCheck, validationErr := common.ValidateRedaction( badEvents, needPowerLevelCheck, validationErr := common.ValidateRedaction(
eventIDToEvent[redactedEventID], redactionIDToEvent[redactedByID], redactedIDToEvent[redactedEventID], redactionIDToEvent[redactedByID],
) )
if validationErr != nil { if validationErr != nil {
return nil, validationErr return nil, validationErr