Merge branch 'master' into hs/fix-appsevice-alias-queries-part-2

This commit is contained in:
Neil Alexander 2021-03-03 14:36:47 +00:00 committed by GitHub
commit 7535f88502
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
77 changed files with 834 additions and 867 deletions

View file

@ -102,7 +102,7 @@ linters-settings:
#local-prefixes: github.com/org/project #local-prefixes: github.com/org/project
gocyclo: gocyclo:
# minimal code complexity to report, 30 by default (but we recommend 10-20) # minimal code complexity to report, 30 by default (but we recommend 10-20)
min-complexity: 13 min-complexity: 25
maligned: maligned:
# print struct with more effective memory layout or not, false by default # print struct with more effective memory layout or not, false by default
suggest-new: true suggest-new: true

View file

@ -1,5 +1,43 @@
# Changelog # Changelog
## Dendrite 0.3.11 (2021-03-02)
### Fixes
- **SECURITY:** A bug in SQLite mode which could cause the registration flow to complete unexpectedly for existing accounts has been fixed (PostgreSQL deployments are not affected)
- A panic in the federation sender has been fixed when shutting down destination queues
- The `/keys/upload` endpoint now correctly returns the number of one-time keys in response to an empty upload request
## Dendrite 0.3.10 (2021-02-17)
### Features
* In-memory caches will now gradually evict old entries, reducing idle memory usage
* Federation sender queues will now be fully unloaded when idle, reducing idle memory usage
* The `power_level_content_override` option is now supported in `/createRoom`
* The `/send` endpoint will now attempt more servers in the room when trying to fetch missing events or state
### Fixes
* A panic in the membership updater has been fixed
* Events in the sync API that weren't excluded from sync can no longer be incorrectly excluded from sync by backfill
* Retrieving remote media now correcly respects the locally configured maximum file size, even when the `Content-Length` header is unavailable
* The `/send` endpoint will no longer hit the database more than once to find servers in the room
## Dendrite 0.3.9 (2021-02-04)
### Features
* Performance of initial/complete syncs has been improved dramatically
* State events that can't be authed are now dropped when joining a room rather than unexpectedly causing the room join to fail
* State events that already appear in the timeline will no longer be requested from the sync API database more than once, which may reduce memory usage in some cases
### Fixes
* A crash at startup due to a conflict in the sync API account data has been fixed
* A crash at startup due to mismatched event IDs in the federation sender has been fixed
* A redundant check which may cause the roomserver memberships table to get out of sync has been removed
## Dendrite 0.3.8 (2021-01-28) ## Dendrite 0.3.8 (2021-01-28)
### Fixes ### Fixes

View file

@ -23,13 +23,13 @@ RUN apt-get update && apt-get -y install python
WORKDIR /build WORKDIR /build
ADD https://github.com/matrix-org/go-http-js-libp2p/archive/master.tar.gz /build/libp2p.tar.gz ADD https://github.com/matrix-org/go-http-js-libp2p/archive/master.tar.gz /build/libp2p.tar.gz
RUN tar xvfz libp2p.tar.gz RUN tar xvfz libp2p.tar.gz
ADD https://github.com/vector-im/riot-web/archive/matthew/p2p.tar.gz /build/p2p.tar.gz ADD https://github.com/vector-im/element-web/archive/matthew/p2p.tar.gz /build/p2p.tar.gz
RUN tar xvfz p2p.tar.gz RUN tar xvfz p2p.tar.gz
# Install deps for riot-web, symlink in libp2p repo and build that too # Install deps for element-web, symlink in libp2p repo and build that too
WORKDIR /build/riot-web-matthew-p2p WORKDIR /build/element-web-matthew-p2p
RUN yarn install RUN yarn install
RUN ln -s /build/go-http-js-libp2p-master /build/riot-web-matthew-p2p/node_modules/go-http-js-libp2p RUN ln -s /build/go-http-js-libp2p-master /build/element-web-matthew-p2p/node_modules/go-http-js-libp2p
RUN (cd node_modules/go-http-js-libp2p && yarn install) RUN (cd node_modules/go-http-js-libp2p && yarn install)
COPY --from=gobuild /build/dendrite-master/main.wasm ./src/vector/dendrite.wasm COPY --from=gobuild /build/dendrite-master/main.wasm ./src/vector/dendrite.wasm
# build it all # build it all
@ -108,4 +108,4 @@ server { \n\
} \n\ } \n\
}' > /etc/nginx/conf.d/default.conf }' > /etc/nginx/conf.d/default.conf
RUN sed -i 's/}/ application\/wasm wasm;\n}/g' /etc/nginx/mime.types RUN sed -i 's/}/ application\/wasm wasm;\n}/g' /etc/nginx/mime.types
COPY --from=jsbuild /build/riot-web-matthew-p2p/webapp /usr/share/nginx/html COPY --from=jsbuild /build/element-web-matthew-p2p/webapp /usr/share/nginx/html

View file

@ -14,7 +14,7 @@ RUN go build -trimpath -o bin/ ./cmd/generate-keys
FROM alpine:latest FROM alpine:latest
COPY --from=base /build/bin/* /usr/bin COPY --from=base /build/bin/* /usr/bin/
VOLUME /etc/dendrite VOLUME /etc/dendrite
WORKDIR /etc/dendrite WORKDIR /etc/dendrite

View file

@ -14,7 +14,7 @@ RUN go build -trimpath -o bin/ ./cmd/generate-keys
FROM alpine:latest FROM alpine:latest
COPY --from=base /build/bin/* /usr/bin COPY --from=base /build/bin/* /usr/bin/
VOLUME /etc/dendrite VOLUME /etc/dendrite
WORKDIR /etc/dendrite WORKDIR /etc/dendrite

View file

@ -309,12 +309,12 @@ user_api:
listen: http://0.0.0.0:7781 listen: http://0.0.0.0:7781
connect: http://user_api:7781 connect: http://user_api:7781
account_database: account_database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_account?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_userapi_accounts?sslmode=disable
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
device_database: device_database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_device?sslmode=disable connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_userapi_devices?sslmode=disable
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1

View file

@ -12,6 +12,7 @@ services:
- 8448:8448 - 8448:8448
volumes: volumes:
- ./config:/etc/dendrite - ./config:/etc/dendrite
- ./media:/var/dendrite/media
networks: networks:
- internal - internal

View file

@ -15,6 +15,7 @@ services:
command: mediaapi command: mediaapi
volumes: volumes:
- ./config:/etc/dendrite - ./config:/etc/dendrite
- ./media:/var/dendrite/media
networks: networks:
- internal - internal
@ -70,7 +71,7 @@ services:
volumes: volumes:
- ./config:/etc/dendrite - ./config:/etc/dendrite
networks: networks:
- internal - internal
signing_key_server: signing_key_server:
hostname: signing_key_server hostname: signing_key_server
@ -86,9 +87,9 @@ services:
image: matrixdotorg/dendrite-polylith:latest image: matrixdotorg/dendrite-polylith:latest
command: userapi command: userapi
volumes: volumes:
- ./config:/etc/dendrite - ./config:/etc/dendrite
networks: networks:
- internal - internal
appservice_api: appservice_api:
hostname: appservice_api hostname: appservice_api

View file

@ -1,5 +1,5 @@
#!/bin/sh #!/bin/sh
for db in account device mediaapi syncapi roomserver signingkeyserver keyserver federationsender appservice naffka; do for db in userapi_accounts userapi_devices mediaapi syncapi roomserver signingkeyserver keyserver federationsender appservice naffka; do
createdb -U dendrite -O dendrite dendrite_$db createdb -U dendrite -O dendrite dendrite_$db
done done

View file

@ -38,16 +38,17 @@ import (
// https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-client-r0-createroom // https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-client-r0-createroom
type createRoomRequest struct { type createRoomRequest struct {
Invite []string `json:"invite"` Invite []string `json:"invite"`
Name string `json:"name"` Name string `json:"name"`
Visibility string `json:"visibility"` Visibility string `json:"visibility"`
Topic string `json:"topic"` Topic string `json:"topic"`
Preset string `json:"preset"` Preset string `json:"preset"`
CreationContent map[string]interface{} `json:"creation_content"` CreationContent map[string]interface{} `json:"creation_content"`
InitialState []fledglingEvent `json:"initial_state"` InitialState []fledglingEvent `json:"initial_state"`
RoomAliasName string `json:"room_alias_name"` RoomAliasName string `json:"room_alias_name"`
GuestCanJoin bool `json:"guest_can_join"` GuestCanJoin bool `json:"guest_can_join"`
RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"`
PowerLevelContentOverride json.RawMessage `json:"power_level_content_override"`
} }
const ( const (
@ -258,6 +259,18 @@ func createRoom(
var builtEvents []*gomatrixserverlib.HeaderedEvent var builtEvents []*gomatrixserverlib.HeaderedEvent
powerLevelContent := eventutil.InitialPowerLevelsContent(userID)
if r.PowerLevelContentOverride != nil {
// Merge powerLevelContentOverride fields by unmarshalling it atop the defaults
err = json.Unmarshal(r.PowerLevelContentOverride, &powerLevelContent)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("malformed power_level_content_override"),
}
}
}
// send events into the room in order of: // send events into the room in order of:
// 1- m.room.create // 1- m.room.create
// 2- room creator join member // 2- room creator join member
@ -279,7 +292,7 @@ func createRoom(
eventsToMake := []fledglingEvent{ eventsToMake := []fledglingEvent{
{"m.room.create", "", r.CreationContent}, {"m.room.create", "", r.CreationContent},
{"m.room.member", userID, membershipContent}, {"m.room.member", userID, membershipContent},
{"m.room.power_levels", "", eventutil.InitialPowerLevelsContent(userID)}, {"m.room.power_levels", "", powerLevelContent},
{"m.room.join_rules", "", gomatrixserverlib.JoinRuleContent{JoinRule: joinRules}}, {"m.room.join_rules", "", gomatrixserverlib.JoinRuleContent{JoinRule: joinRules}},
{"m.room.history_visibility", "", eventutil.HistoryVisibilityContent{HistoryVisibility: historyVisibility}}, {"m.room.history_visibility", "", eventutil.HistoryVisibilityContent{HistoryVisibility: historyVisibility}},
} }

View file

@ -38,7 +38,10 @@ func UploadKeys(req *http.Request, keyAPI api.KeyInternalAPI, device *userapi.De
return *resErr return *resErr
} }
uploadReq := &api.PerformUploadKeysRequest{} uploadReq := &api.PerformUploadKeysRequest{
DeviceID: device.ID,
UserID: device.UserID,
}
if r.DeviceKeys != nil { if r.DeviceKeys != nil {
uploadReq.DeviceKeys = []api.DeviceKeys{ uploadReq.DeviceKeys = []api.DeviceKeys{
{ {

View file

@ -91,7 +91,6 @@ func GetAvatarURL(
} }
// SetAvatarURL implements PUT /profile/{userID}/avatar_url // SetAvatarURL implements PUT /profile/{userID}/avatar_url
// nolint:gocyclo
func SetAvatarURL( func SetAvatarURL(
req *http.Request, accountDB accounts.Database, req *http.Request, accountDB accounts.Database,
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI,
@ -209,7 +208,6 @@ func GetDisplayName(
} }
// SetDisplayName implements PUT /profile/{userID}/displayname // SetDisplayName implements PUT /profile/{userID}/displayname
// nolint:gocyclo
func SetDisplayName( func SetDisplayName(
req *http.Request, accountDB accounts.Database, req *http.Request, accountDB accounts.Database,
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI,

View file

@ -161,7 +161,6 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
// state to see if there is an event with that type and state key, if there // state to see if there is an event with that type and state key, if there
// is then (by default) we return the content, otherwise a 404. // is then (by default) we return the content, otherwise a 404.
// If eventFormat=true, sends the whole event else just the content. // If eventFormat=true, sends the whole event else just the content.
// nolint:gocyclo
func OnIncomingStateTypeRequest( func OnIncomingStateTypeRequest(
ctx context.Context, device *userapi.Device, rsAPI api.RoomserverInternalAPI, ctx context.Context, device *userapi.Device, rsAPI api.RoomserverInternalAPI,
roomID, evType, stateKey string, eventFormat bool, roomID, evType, stateKey string, eventFormat bool,

View file

@ -52,7 +52,6 @@ var (
instancePeer = flag.String("peer", "", "an internet Yggdrasil peer to connect to") instancePeer = flag.String("peer", "", "an internet Yggdrasil peer to connect to")
) )
// nolint:gocyclo
func main() { func main() {
flag.Parse() flag.Parse()
internal.SetupPprof() internal.SetupPprof()

View file

@ -73,7 +73,6 @@ func (n *Node) DialerContext(ctx context.Context, network, address string) (net.
return n.Dialer(network, address) return n.Dialer(network, address)
} }
// nolint:gocyclo
func Setup(instanceName, storageDirectory string) (*Node, error) { func Setup(instanceName, storageDirectory string) (*Node, error) {
n := &Node{ n := &Node{
core: &yggdrasil.Core{}, core: &yggdrasil.Core{},

View file

@ -128,7 +128,6 @@ func (n *Node) Dial(network, address string) (net.Conn, error) {
} }
// Implements http.Transport.DialContext // Implements http.Transport.DialContext
// nolint:gocyclo
func (n *Node) DialContext(ctx context.Context, network, address string) (net.Conn, error) { func (n *Node) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
s, ok1 := n.sessions.Load(address) s, ok1 := n.sessions.Load(address)
session, ok2 := s.(*session) session, ok2 := s.(*session)

View file

@ -20,7 +20,6 @@ var requestFrom = flag.String("from", "", "the server name that the request shou
var requestKey = flag.String("key", "matrix_key.pem", "the private key to use when signing the request") var requestKey = flag.String("key", "matrix_key.pem", "the private key to use when signing the request")
var requestPost = flag.Bool("post", false, "send a POST request instead of GET (pipe input into stdin or type followed by Ctrl-D)") var requestPost = flag.Bool("post", false, "send a POST request instead of GET (pipe input into stdin or type followed by Ctrl-D)")
// nolint:gocyclo
func main() { func main() {
flag.Parse() flag.Parse()

View file

@ -8,7 +8,6 @@ import (
"strconv" "strconv"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup"
@ -25,7 +24,6 @@ import (
var roomVersion = flag.String("roomversion", "5", "the room version to parse events as") var roomVersion = flag.String("roomversion", "5", "the room version to parse events as")
// nolint:gocyclo
func main() { func main() {
ctx := context.Background() ctx := context.Background()
cfg := setup.ParseFlags(true) cfg := setup.ParseFlags(true)
@ -105,7 +103,7 @@ func main() {
} }
fmt.Println("Resolving state") fmt.Println("Resolving state")
resolved, err := state.ResolveConflictsAdhoc( resolved, err := gomatrixserverlib.ResolveConflicts(
gomatrixserverlib.RoomVersion(*roomVersion), gomatrixserverlib.RoomVersion(*roomVersion),
events, events,
authEvents, authEvents,

View file

@ -109,7 +109,7 @@ On macOS, omit `sudo -u postgres` from the below commands.
* If you want to run each Dendrite component with its own database: * If you want to run each Dendrite component with its own database:
```bash ```bash
for i in mediaapi syncapi roomserver signingkeyserver federationsender appservice keyserver userapi_account userapi_device naffka; do for i in mediaapi syncapi roomserver signingkeyserver federationsender appservice keyserver userapi_accounts userapi_devices naffka; do
sudo -u postgres createdb -O dendrite dendrite_$i sudo -u postgres createdb -O dendrite dendrite_$i
done done
``` ```

View file

@ -29,7 +29,6 @@ import (
) )
// MakeJoin implements the /make_join API // MakeJoin implements the /make_join API
// nolint:gocyclo
func MakeJoin( func MakeJoin(
httpReq *http.Request, httpReq *http.Request,
request *gomatrixserverlib.FederationRequest, request *gomatrixserverlib.FederationRequest,
@ -161,7 +160,6 @@ func MakeJoin(
// SendJoin implements the /send_join API // SendJoin implements the /send_join API
// The make-join send-join dance makes much more sense as a single // The make-join send-join dance makes much more sense as a single
// flow so the cyclomatic complexity is high: // flow so the cyclomatic complexity is high:
// nolint:gocyclo
func SendJoin( func SendJoin(
httpReq *http.Request, httpReq *http.Request,
request *gomatrixserverlib.FederationRequest, request *gomatrixserverlib.FederationRequest,

View file

@ -25,7 +25,6 @@ import (
) )
// MakeLeave implements the /make_leave API // MakeLeave implements the /make_leave API
// nolint:gocyclo
func MakeLeave( func MakeLeave(
httpReq *http.Request, httpReq *http.Request,
request *gomatrixserverlib.FederationRequest, request *gomatrixserverlib.FederationRequest,
@ -118,7 +117,6 @@ func MakeLeave(
} }
// SendLeave implements the /send_leave API // SendLeave implements the /send_leave API
// nolint:gocyclo
func SendLeave( func SendLeave(
httpReq *http.Request, httpReq *http.Request,
request *gomatrixserverlib.FederationRequest, request *gomatrixserverlib.FederationRequest,

View file

@ -111,7 +111,6 @@ func fillPublicRoomsReq(httpReq *http.Request, request *PublicRoomReq) *util.JSO
} }
// due to lots of switches // due to lots of switches
// nolint:gocyclo
func fillInRooms(ctx context.Context, roomIDs []string, rsAPI roomserverAPI.RoomserverInternalAPI) ([]gomatrixserverlib.PublicRoom, error) { func fillInRooms(ctx context.Context, roomIDs []string, rsAPI roomserverAPI.RoomserverInternalAPI) ([]gomatrixserverlib.PublicRoom, error) {
avatarTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.avatar", StateKey: ""} avatarTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.avatar", StateKey: ""}
nameTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.name", StateKey: ""} nameTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.name", StateKey: ""}

View file

@ -102,11 +102,13 @@ func Send(
type txnReq struct { type txnReq struct {
gomatrixserverlib.Transaction gomatrixserverlib.Transaction
rsAPI api.RoomserverInternalAPI rsAPI api.RoomserverInternalAPI
eduAPI eduserverAPI.EDUServerInputAPI eduAPI eduserverAPI.EDUServerInputAPI
keyAPI keyapi.KeyInternalAPI keyAPI keyapi.KeyInternalAPI
keys gomatrixserverlib.JSONVerifier keys gomatrixserverlib.JSONVerifier
federation txnFederationClient federation txnFederationClient
servers []gomatrixserverlib.ServerName
serversMutex sync.RWMutex
// local cache of events for auth checks, etc - this may include events // local cache of events for auth checks, etc - this may include events
// which the roomserver is unaware of. // which the roomserver is unaware of.
haveEvents map[string]*gomatrixserverlib.HeaderedEvent haveEvents map[string]*gomatrixserverlib.HeaderedEvent
@ -277,7 +279,6 @@ func (t *txnReq) haveEventIDs() map[string]bool {
return result return result
} }
// nolint:gocyclo
func (t *txnReq) processEDUs(ctx context.Context) { func (t *txnReq) processEDUs(ctx context.Context) {
for _, e := range t.EDUs { for _, e := range t.EDUs {
switch e.Type { switch e.Type {
@ -404,16 +405,21 @@ func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverli
} }
func (t *txnReq) getServers(ctx context.Context, roomID string) []gomatrixserverlib.ServerName { func (t *txnReq) getServers(ctx context.Context, roomID string) []gomatrixserverlib.ServerName {
servers := []gomatrixserverlib.ServerName{t.Origin} t.serversMutex.Lock()
defer t.serversMutex.Unlock()
if t.servers != nil {
return t.servers
}
t.servers = []gomatrixserverlib.ServerName{t.Origin}
serverReq := &api.QueryServerJoinedToRoomRequest{ serverReq := &api.QueryServerJoinedToRoomRequest{
RoomID: roomID, RoomID: roomID,
} }
serverRes := &api.QueryServerJoinedToRoomResponse{} serverRes := &api.QueryServerJoinedToRoomResponse{}
if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil { if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil {
servers = append(servers, serverRes.ServerNames...) t.servers = append(t.servers, serverRes.ServerNames...)
util.GetLogger(ctx).Infof("Found %d server(s) to query for missing events in %q", len(servers), roomID) util.GetLogger(ctx).Infof("Found %d server(s) to query for missing events in %q", len(t.servers), roomID)
} }
return servers return t.servers
} }
func (t *txnReq) processEvent(ctx context.Context, e *gomatrixserverlib.Event) error { func (t *txnReq) processEvent(ctx context.Context, e *gomatrixserverlib.Event) error {
@ -482,14 +488,10 @@ func (t *txnReq) retrieveMissingAuthEvents(
missingAuthEvents[missingAuthEventID] = struct{}{} missingAuthEvents[missingAuthEventID] = struct{}{}
} }
servers := t.getServers(ctx, e.RoomID())
if len(servers) > 5 {
servers = servers[:5]
}
withNextEvent: withNextEvent:
for missingAuthEventID := range missingAuthEvents { for missingAuthEventID := range missingAuthEvents {
withNextServer: withNextServer:
for _, server := range servers { for _, server := range t.getServers(ctx, e.RoomID()) {
logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server) logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server)
tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID) tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID)
if err != nil { if err != nil {
@ -537,7 +539,6 @@ func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []*gomatrixserv
return gomatrixserverlib.Allowed(e, &authUsingState) return gomatrixserverlib.Allowed(e, &authUsingState)
} }
// nolint:gocyclo
func (t *txnReq) processEventWithMissingState(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) error { func (t *txnReq) processEventWithMissingState(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) error {
// Do this with a fresh context, so that we keep working even if the // Do this with a fresh context, so that we keep working even if the
// original request times out. With any luck, by the time the remote // original request times out. With any luck, by the time the remote
@ -692,13 +693,8 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix
return nil, false, fmt.Errorf("t.lookupStateBeforeEvent: %w", err) return nil, false, fmt.Errorf("t.lookupStateBeforeEvent: %w", err)
} }
servers := t.getServers(ctx, roomID)
if len(servers) > 5 {
servers = servers[:5]
}
// fetch the event we're missing and add it to the pile // fetch the event we're missing and add it to the pile
h, err := t.lookupEvent(ctx, roomVersion, eventID, false, servers) h, err := t.lookupEvent(ctx, roomVersion, roomID, eventID, false)
switch err.(type) { switch err.(type) {
case verifySigError: case verifySigError:
return respState, false, nil return respState, false, nil
@ -740,11 +736,10 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event
t.haveEvents[ev.EventID()] = res.StateEvents[i] t.haveEvents[ev.EventID()] = res.StateEvents[i]
} }
var authEvents []*gomatrixserverlib.Event var authEvents []*gomatrixserverlib.Event
missingAuthEvents := make(map[string]bool) missingAuthEvents := map[string]bool{}
for _, ev := range res.StateEvents { for _, ev := range res.StateEvents {
for _, ae := range ev.AuthEventIDs() { for _, ae := range ev.AuthEventIDs() {
aev, ok := t.haveEvents[ae] if aev, ok := t.haveEvents[ae]; ok {
if ok {
authEvents = append(authEvents, aev.Unwrap()) authEvents = append(authEvents, aev.Unwrap())
} else { } else {
missingAuthEvents[ae] = true missingAuthEvents[ae] = true
@ -753,27 +748,28 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event
} }
// QueryStateAfterEvents does not return the auth events, so fetch them now. We know the roomserver has them else it wouldn't // QueryStateAfterEvents does not return the auth events, so fetch them now. We know the roomserver has them else it wouldn't
// have stored the event. // have stored the event.
var missingEventList []string if len(missingAuthEvents) > 0 {
for evID := range missingAuthEvents { var missingEventList []string
missingEventList = append(missingEventList, evID) for evID := range missingAuthEvents {
} missingEventList = append(missingEventList, evID)
queryReq := api.QueryEventsByIDRequest{ }
EventIDs: missingEventList, queryReq := api.QueryEventsByIDRequest{
} EventIDs: missingEventList,
util.GetLogger(ctx).Infof("Fetching missing auth events: %v", missingEventList) }
var queryRes api.QueryEventsByIDResponse util.GetLogger(ctx).Infof("Fetching missing auth events: %v", missingEventList)
if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { var queryRes api.QueryEventsByIDResponse
return nil if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil {
} return nil
for i := range queryRes.Events { }
evID := queryRes.Events[i].EventID() for i := range queryRes.Events {
t.haveEvents[evID] = queryRes.Events[i] evID := queryRes.Events[i].EventID()
authEvents = append(authEvents, queryRes.Events[i].Unwrap()) t.haveEvents[evID] = queryRes.Events[i]
authEvents = append(authEvents, queryRes.Events[i].Unwrap())
}
} }
evs := gomatrixserverlib.UnwrapEventHeaders(res.StateEvents)
return &gomatrixserverlib.RespState{ return &gomatrixserverlib.RespState{
StateEvents: evs, StateEvents: gomatrixserverlib.UnwrapEventHeaders(res.StateEvents),
AuthEvents: authEvents, AuthEvents: authEvents,
} }
} }
@ -805,11 +801,7 @@ retryAllowedState:
if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents); err != nil { if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents); err != nil {
switch missing := err.(type) { switch missing := err.(type) {
case gomatrixserverlib.MissingAuthEventError: case gomatrixserverlib.MissingAuthEventError:
servers := t.getServers(ctx, backwardsExtremity.RoomID()) h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true)
if len(servers) > 5 {
servers = servers[:5]
}
h, err2 := t.lookupEvent(ctx, roomVersion, missing.AuthEventID, true, servers)
switch err2.(type) { switch err2.(type) {
case verifySigError: case verifySigError:
return &gomatrixserverlib.RespState{ return &gomatrixserverlib.RespState{
@ -838,7 +830,6 @@ retryAllowedState:
// begin from. Returns an error only if we should terminate the transaction which initiated /get_missing_events // begin from. Returns an error only if we should terminate the transaction which initiated /get_missing_events
// This function recursively calls txnReq.processEvent with the missing events, which will be processed before this function returns. // This function recursively calls txnReq.processEvent with the missing events, which will be processed before this function returns.
// This means that we may recursively call this function, as we spider back up prev_events. // This means that we may recursively call this function, as we spider back up prev_events.
// nolint:gocyclo
func (t *txnReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, err error) { func (t *txnReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, err error) {
logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID())
needed := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{e}) needed := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{e})
@ -857,17 +848,8 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Even
latestEvents[i] = res.LatestEvents[i].EventID latestEvents[i] = res.LatestEvents[i].EventID
} }
servers := []gomatrixserverlib.ServerName{t.Origin}
serverReq := &api.QueryServerJoinedToRoomRequest{
RoomID: e.RoomID(),
}
serverRes := &api.QueryServerJoinedToRoomResponse{}
if err = t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil {
servers = append(servers, serverRes.ServerNames...)
logger.Infof("Found %d server(s) to query for missing events", len(servers))
}
var missingResp *gomatrixserverlib.RespMissingEvents var missingResp *gomatrixserverlib.RespMissingEvents
servers := t.getServers(ctx, e.RoomID())
for _, server := range servers { for _, server := range servers {
var m gomatrixserverlib.RespMissingEvents var m gomatrixserverlib.RespMissingEvents
if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{ if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{
@ -950,7 +932,6 @@ func (t *txnReq) lookupMissingStateViaState(ctx context.Context, roomID, eventID
return &state, nil return &state, nil
} }
// nolint:gocyclo
func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
*gomatrixserverlib.RespState, error) { *gomatrixserverlib.RespState, error) {
util.GetLogger(ctx).Infof("lookupMissingStateViaStateIDs %s", eventID) util.GetLogger(ctx).Infof("lookupMissingStateViaStateIDs %s", eventID)
@ -1015,12 +996,6 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
"concurrent_requests": concurrentRequests, "concurrent_requests": concurrentRequests,
}).Info("Fetching missing state at event") }).Info("Fetching missing state at event")
// Get a list of servers to fetch from.
servers := t.getServers(ctx, roomID)
if len(servers) > 5 {
servers = servers[:5]
}
// Create a queue containing all of the missing event IDs that we want // Create a queue containing all of the missing event IDs that we want
// to retrieve. // to retrieve.
pending := make(chan string, missingCount) pending := make(chan string, missingCount)
@ -1046,7 +1021,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
// Define what we'll do in order to fetch the missing event ID. // Define what we'll do in order to fetch the missing event ID.
fetch := func(missingEventID string) { fetch := func(missingEventID string) {
var h *gomatrixserverlib.HeaderedEvent var h *gomatrixserverlib.HeaderedEvent
h, err = t.lookupEvent(ctx, roomVersion, missingEventID, false, servers) h, err = t.lookupEvent(ctx, roomVersion, roomID, missingEventID, false)
switch err.(type) { switch err.(type) {
case verifySigError: case verifySigError:
return return
@ -1112,7 +1087,7 @@ func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStat
return &respState, nil return &respState, nil
} }
func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, missingEventID string, localFirst bool, servers []gomatrixserverlib.ServerName) (*gomatrixserverlib.HeaderedEvent, error) { func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) {
if localFirst { if localFirst {
// fetch from the roomserver // fetch from the roomserver
queryReq := api.QueryEventsByIDRequest{ queryReq := api.QueryEventsByIDRequest{
@ -1127,6 +1102,7 @@ func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.
} }
var event *gomatrixserverlib.Event var event *gomatrixserverlib.Event
found := false found := false
servers := t.getServers(ctx, roomID)
for _, serverName := range servers { for _, serverName := range servers {
txn, err := t.federation.GetEvent(ctx, serverName, missingEventID) txn, err := t.federation.GetEvent(ctx, serverName, missingEventID)
if err != nil || len(txn.PDUs) == 0 { if err != nil || len(txn.PDUs) == 0 {

View file

@ -46,6 +46,7 @@ const (
// ensures that only one request is in flight to a given destination // ensures that only one request is in flight to a given destination
// at a time. // at a time.
type destinationQueue struct { type destinationQueue struct {
queues *OutgoingQueues
db storage.Database db storage.Database
process *process.ProcessContext process *process.ProcessContext
signing *SigningInfo signing *SigningInfo
@ -172,7 +173,6 @@ func (oq *destinationQueue) wakeQueueIfNeeded() {
// getPendingFromDatabase will look at the database and see if // getPendingFromDatabase will look at the database and see if
// there are any persisted events that haven't been sent to this // there are any persisted events that haven't been sent to this
// destination yet. If so, they will be queued up. // destination yet. If so, they will be queued up.
// nolint:gocyclo
func (oq *destinationQueue) getPendingFromDatabase() { func (oq *destinationQueue) getPendingFromDatabase() {
// Check to see if there's anything to do for this server // Check to see if there's anything to do for this server
// in the database. // in the database.
@ -237,7 +237,6 @@ func (oq *destinationQueue) getPendingFromDatabase() {
} }
// backgroundSend is the worker goroutine for sending events. // backgroundSend is the worker goroutine for sending events.
// nolint:gocyclo
func (oq *destinationQueue) backgroundSend() { func (oq *destinationQueue) backgroundSend() {
// Check if a worker is already running, and if it isn't, then // Check if a worker is already running, and if it isn't, then
// mark it as started. // mark it as started.
@ -246,6 +245,7 @@ func (oq *destinationQueue) backgroundSend() {
} }
destinationQueueRunning.Inc() destinationQueueRunning.Inc()
defer destinationQueueRunning.Dec() defer destinationQueueRunning.Dec()
defer oq.queues.clearQueue(oq)
defer oq.running.Store(false) defer oq.running.Store(false)
// Mark the queue as overflowed, so we will consult the database // Mark the queue as overflowed, so we will consult the database
@ -351,7 +351,6 @@ func (oq *destinationQueue) backgroundSend() {
// nextTransaction creates a new transaction from the pending event // nextTransaction creates a new transaction from the pending event
// queue and sends it. Returns true if a transaction was sent or // queue and sends it. Returns true if a transaction was sent or
// false otherwise. // false otherwise.
// nolint:gocyclo
func (oq *destinationQueue) nextTransaction( func (oq *destinationQueue) nextTransaction(
pdus []*queuedPDU, pdus []*queuedPDU,
edus []*queuedEDU, edus []*queuedEDU,
@ -444,7 +443,7 @@ func (oq *destinationQueue) nextTransaction(
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"destination": oq.destination, "destination": oq.destination,
log.ErrorKey: err, log.ErrorKey: err,
}).Infof("Failed to send transaction %q", t.TransactionID) }).Debugf("Failed to send transaction %q", t.TransactionID)
return false, 0, 0, err return false, 0, 0, err
} }
} }

View file

@ -120,7 +120,7 @@ func NewOutgoingQueues(
log.WithError(err).Error("Failed to get EDU server names for destination queue hydration") log.WithError(err).Error("Failed to get EDU server names for destination queue hydration")
} }
for serverName := range serverNames { for serverName := range serverNames {
if queue := queues.getQueue(serverName); !queue.statistics.Blacklisted() { if queue := queues.getQueue(serverName); queue != nil {
queue.wakeQueueIfNeeded() queue.wakeQueueIfNeeded()
} }
} }
@ -148,12 +148,16 @@ type queuedEDU struct {
} }
func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *destinationQueue { func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *destinationQueue {
if oqs.statistics.ForServer(destination).Blacklisted() {
return nil
}
oqs.queuesMutex.Lock() oqs.queuesMutex.Lock()
defer oqs.queuesMutex.Unlock() defer oqs.queuesMutex.Unlock()
oq := oqs.queues[destination] oq, ok := oqs.queues[destination]
if oq == nil { if !ok || oq != nil {
destinationQueueTotal.Inc() destinationQueueTotal.Inc()
oq = &destinationQueue{ oq = &destinationQueue{
queues: oqs,
db: oqs.db, db: oqs.db,
process: oqs.process, process: oqs.process,
rsAPI: oqs.rsAPI, rsAPI: oqs.rsAPI,
@ -170,6 +174,14 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d
return oq return oq
} }
func (oqs *OutgoingQueues) clearQueue(oq *destinationQueue) {
oqs.queuesMutex.Lock()
defer oqs.queuesMutex.Unlock()
delete(oqs.queues, oq.destination)
destinationQueueTotal.Dec()
}
type ErrorFederationDisabled struct { type ErrorFederationDisabled struct {
Message string Message string
} }
@ -236,7 +248,9 @@ func (oqs *OutgoingQueues) SendEvent(
} }
for destination := range destmap { for destination := range destmap {
oqs.getQueue(destination).sendEvent(ev, nid) if queue := oqs.getQueue(destination); queue != nil {
queue.sendEvent(ev, nid)
}
} }
return nil return nil
@ -306,7 +320,9 @@ func (oqs *OutgoingQueues) SendEDU(
} }
for destination := range destmap { for destination := range destmap {
oqs.getQueue(destination).sendEDU(e, nid) if queue := oqs.getQueue(destination); queue != nil {
queue.sendEDU(e, nid)
}
} }
return nil return nil
@ -317,9 +333,7 @@ func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) {
if oqs.disabled { if oqs.disabled {
return return
} }
q := oqs.getQueue(srv) if queue := oqs.getQueue(srv); queue != nil {
if q == nil { queue.wakeQueueIfNeeded()
return
} }
q.wakeQueueIfNeeded()
} }

View file

@ -0,0 +1,46 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
)
func LoadFromGoose() {
goose.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable)
}
func LoadRemoveRoomsTable(m *sqlutil.Migrations) {
m.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable)
}
func UpRemoveRoomsTable(tx *sql.Tx) error {
_, err := tx.Exec(`
DROP TABLE IF EXISTS federationsender_rooms;
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownRemoveRoomsTable(tx *sql.Tx) error {
// We can't reverse this.
return nil
}

View file

@ -1,104 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const roomSchema = `
CREATE TABLE IF NOT EXISTS federationsender_rooms (
-- The string ID of the room
room_id TEXT PRIMARY KEY,
-- The most recent event state by the room server.
-- We can use this to tell if our view of the room state has become
-- desynchronised.
last_event_id TEXT NOT NULL
);`
const insertRoomSQL = "" +
"INSERT INTO federationsender_rooms (room_id, last_event_id) VALUES ($1, '')" +
" ON CONFLICT DO NOTHING"
const selectRoomForUpdateSQL = "" +
"SELECT last_event_id FROM federationsender_rooms WHERE room_id = $1 FOR UPDATE"
const updateRoomSQL = "" +
"UPDATE federationsender_rooms SET last_event_id = $2 WHERE room_id = $1"
type roomStatements struct {
db *sql.DB
insertRoomStmt *sql.Stmt
selectRoomForUpdateStmt *sql.Stmt
updateRoomStmt *sql.Stmt
}
func NewPostgresRoomsTable(db *sql.DB) (s *roomStatements, err error) {
s = &roomStatements{
db: db,
}
_, err = s.db.Exec(roomSchema)
if err != nil {
return
}
if s.insertRoomStmt, err = s.db.Prepare(insertRoomSQL); err != nil {
return
}
if s.selectRoomForUpdateStmt, err = s.db.Prepare(selectRoomForUpdateSQL); err != nil {
return
}
if s.updateRoomStmt, err = s.db.Prepare(updateRoomSQL); err != nil {
return
}
return
}
// insertRoom inserts the room if it didn't already exist.
// If the room didn't exist then last_event_id is set to the empty string.
func (s *roomStatements) InsertRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
_, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID)
return err
}
// selectRoomForUpdate locks the row for the room and returns the last_event_id.
// The row must already exist in the table. Callers can ensure that the row
// exists by calling insertRoom first.
func (s *roomStatements) SelectRoomForUpdate(
ctx context.Context, txn *sql.Tx, roomID string,
) (string, error) {
var lastEventID string
stmt := sqlutil.TxStmt(txn, s.selectRoomForUpdateStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(&lastEventID)
if err != nil {
return "", err
}
return lastEventID, nil
}
// updateRoom updates the last_event_id for the room. selectRoomForUpdate should
// have already been called earlier within the transaction.
func (s *roomStatements) UpdateRoom(
ctx context.Context, txn *sql.Tx, roomID, lastEventID string,
) error {
stmt := sqlutil.TxStmt(txn, s.updateRoomStmt)
_, err := stmt.ExecContext(ctx, roomID, lastEventID)
return err
}

View file

@ -18,6 +18,7 @@ package postgres
import ( import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/federationsender/storage/postgres/deltas"
"github.com/matrix-org/dendrite/federationsender/storage/shared" "github.com/matrix-org/dendrite/federationsender/storage/shared"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -56,10 +57,6 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationS
if err != nil { if err != nil {
return nil, err return nil, err
} }
rooms, err := NewPostgresRoomsTable(d.db)
if err != nil {
return nil, err
}
blacklist, err := NewPostgresBlacklistTable(d.db) blacklist, err := NewPostgresBlacklistTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
@ -72,6 +69,11 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationS
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrations()
deltas.LoadRemoveRoomsTable(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil {
return nil, err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Cache: cache, Cache: cache,
@ -80,7 +82,6 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationS
FederationSenderQueuePDUs: queuePDUs, FederationSenderQueuePDUs: queuePDUs,
FederationSenderQueueEDUs: queueEDUs, FederationSenderQueueEDUs: queueEDUs,
FederationSenderQueueJSON: queueJSON, FederationSenderQueueJSON: queueJSON,
FederationSenderRooms: rooms,
FederationSenderBlacklist: blacklist, FederationSenderBlacklist: blacklist,
FederationSenderInboundPeeks: inboundPeeks, FederationSenderInboundPeeks: inboundPeeks,
FederationSenderOutboundPeeks: outboundPeeks, FederationSenderOutboundPeeks: outboundPeeks,

View file

@ -34,7 +34,6 @@ type Database struct {
FederationSenderQueueEDUs tables.FederationSenderQueueEDUs FederationSenderQueueEDUs tables.FederationSenderQueueEDUs
FederationSenderQueueJSON tables.FederationSenderQueueJSON FederationSenderQueueJSON tables.FederationSenderQueueJSON
FederationSenderJoinedHosts tables.FederationSenderJoinedHosts FederationSenderJoinedHosts tables.FederationSenderJoinedHosts
FederationSenderRooms tables.FederationSenderRooms
FederationSenderBlacklist tables.FederationSenderBlacklist FederationSenderBlacklist tables.FederationSenderBlacklist
FederationSenderOutboundPeeks tables.FederationSenderOutboundPeeks FederationSenderOutboundPeeks tables.FederationSenderOutboundPeeks
FederationSenderInboundPeeks tables.FederationSenderInboundPeeks FederationSenderInboundPeeks tables.FederationSenderInboundPeeks
@ -64,29 +63,6 @@ func (d *Database) UpdateRoom(
removeHosts []string, removeHosts []string,
) (joinedHosts []types.JoinedHost, err error) { ) (joinedHosts []types.JoinedHost, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err = d.FederationSenderRooms.InsertRoom(ctx, txn, roomID)
if err != nil {
return err
}
lastSentEventID, err := d.FederationSenderRooms.SelectRoomForUpdate(ctx, txn, roomID)
if err != nil {
return err
}
if lastSentEventID == newEventID {
// We've handled this message before, so let's just ignore it.
// We can only get a duplicate for the last message we processed,
// so its enough just to compare the newEventID with lastSentEventID
return nil
}
if lastSentEventID != "" && lastSentEventID != oldEventID {
return types.EventIDMismatchError{
DatabaseID: lastSentEventID, RoomServerID: oldEventID,
}
}
joinedHosts, err = d.FederationSenderJoinedHosts.SelectJoinedHostsWithTx(ctx, txn, roomID) joinedHosts, err = d.FederationSenderJoinedHosts.SelectJoinedHostsWithTx(ctx, txn, roomID)
if err != nil { if err != nil {
return err return err
@ -101,7 +77,7 @@ func (d *Database) UpdateRoom(
if err = d.FederationSenderJoinedHosts.DeleteJoinedHosts(ctx, txn, removeHosts); err != nil { if err = d.FederationSenderJoinedHosts.DeleteJoinedHosts(ctx, txn, removeHosts); err != nil {
return err return err
} }
return d.FederationSenderRooms.UpdateRoom(ctx, txn, roomID, newEventID) return nil
}) })
return return
} }

View file

@ -0,0 +1,46 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
)
func LoadFromGoose() {
goose.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable)
}
func LoadRemoveRoomsTable(m *sqlutil.Migrations) {
m.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable)
}
func UpRemoveRoomsTable(tx *sql.Tx) error {
_, err := tx.Exec(`
DROP TABLE IF EXISTS federationsender_rooms;
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownRemoveRoomsTable(tx *sql.Tx) error {
// We can't reverse this.
return nil
}

View file

@ -1,105 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const roomSchema = `
CREATE TABLE IF NOT EXISTS federationsender_rooms (
-- The string ID of the room
room_id TEXT PRIMARY KEY,
-- The most recent event state by the room server.
-- We can use this to tell if our view of the room state has become
-- desynchronised.
last_event_id TEXT NOT NULL
);`
const insertRoomSQL = "" +
"INSERT INTO federationsender_rooms (room_id, last_event_id) VALUES ($1, '')" +
" ON CONFLICT DO NOTHING"
const selectRoomForUpdateSQL = "" +
"SELECT last_event_id FROM federationsender_rooms WHERE room_id = $1"
const updateRoomSQL = "" +
"UPDATE federationsender_rooms SET last_event_id = $2 WHERE room_id = $1"
type roomStatements struct {
db *sql.DB
insertRoomStmt *sql.Stmt
selectRoomForUpdateStmt *sql.Stmt
updateRoomStmt *sql.Stmt
}
func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) {
s = &roomStatements{
db: db,
}
_, err = db.Exec(roomSchema)
if err != nil {
return
}
if s.insertRoomStmt, err = db.Prepare(insertRoomSQL); err != nil {
return
}
if s.selectRoomForUpdateStmt, err = db.Prepare(selectRoomForUpdateSQL); err != nil {
return
}
if s.updateRoomStmt, err = db.Prepare(updateRoomSQL); err != nil {
return
}
return
}
// insertRoom inserts the room if it didn't already exist.
// If the room didn't exist then last_event_id is set to the empty string.
func (s *roomStatements) InsertRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
_, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID)
return err
}
// selectRoomForUpdate locks the row for the room and returns the last_event_id.
// The row must already exist in the table. Callers can ensure that the row
// exists by calling insertRoom first.
func (s *roomStatements) SelectRoomForUpdate(
ctx context.Context, txn *sql.Tx, roomID string,
) (string, error) {
var lastEventID string
stmt := sqlutil.TxStmt(txn, s.selectRoomForUpdateStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(&lastEventID)
if err != nil {
return "", err
}
return lastEventID, nil
}
// updateRoom updates the last_event_id for the room. selectRoomForUpdate should
// have already been called earlier within the transaction.
func (s *roomStatements) UpdateRoom(
ctx context.Context, txn *sql.Tx, roomID, lastEventID string,
) error {
stmt := sqlutil.TxStmt(txn, s.updateRoomStmt)
_, err := stmt.ExecContext(ctx, roomID, lastEventID)
return err
}

View file

@ -21,6 +21,7 @@ import (
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"github.com/matrix-org/dendrite/federationsender/storage/shared" "github.com/matrix-org/dendrite/federationsender/storage/shared"
"github.com/matrix-org/dendrite/federationsender/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
@ -46,10 +47,6 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationS
if err != nil { if err != nil {
return nil, err return nil, err
} }
rooms, err := NewSQLiteRoomsTable(d.db)
if err != nil {
return nil, err
}
queuePDUs, err := NewSQLiteQueuePDUsTable(d.db) queuePDUs, err := NewSQLiteQueuePDUsTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
@ -74,6 +71,11 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationS
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrations()
deltas.LoadRemoveRoomsTable(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil {
return nil, err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Cache: cache, Cache: cache,
@ -82,7 +84,6 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationS
FederationSenderQueuePDUs: queuePDUs, FederationSenderQueuePDUs: queuePDUs,
FederationSenderQueueEDUs: queueEDUs, FederationSenderQueueEDUs: queueEDUs,
FederationSenderQueueJSON: queueJSON, FederationSenderQueueJSON: queueJSON,
FederationSenderRooms: rooms,
FederationSenderBlacklist: blacklist, FederationSenderBlacklist: blacklist,
FederationSenderOutboundPeeks: outboundPeeks, FederationSenderOutboundPeeks: outboundPeeks,
FederationSenderInboundPeeks: inboundPeeks, FederationSenderInboundPeeks: inboundPeeks,

View file

@ -56,12 +56,6 @@ type FederationSenderJoinedHosts interface {
SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error)
} }
type FederationSenderRooms interface {
InsertRoom(ctx context.Context, txn *sql.Tx, roomID string) error
SelectRoomForUpdate(ctx context.Context, txn *sql.Tx, roomID string) (string, error)
UpdateRoom(ctx context.Context, txn *sql.Tx, roomID, lastEventID string) error
}
type FederationSenderBlacklist interface { type FederationSenderBlacklist interface {
InsertBlacklist(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error InsertBlacklist(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error
SelectBlacklist(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (bool, error) SelectBlacklist(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (bool, error)

View file

@ -15,8 +15,6 @@
package types package types
import ( import (
"fmt"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -34,22 +32,6 @@ func (s ServerNames) Len() int { return len(s) }
func (s ServerNames) Swap(i, j int) { s[i], s[j] = s[j], s[i] } func (s ServerNames) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s ServerNames) Less(i, j int) bool { return s[i] < s[j] } func (s ServerNames) Less(i, j int) bool { return s[i] < s[j] }
// A EventIDMismatchError indicates that we have got out of sync with the
// room server.
type EventIDMismatchError struct {
// The event ID we have stored in our local database.
DatabaseID string
// The event ID received from the room server.
RoomServerID string
}
func (e EventIDMismatchError) Error() string {
return fmt.Sprintf(
"mismatched last sent event ID: had %q in database got %q from room server",
e.DatabaseID, e.RoomServerID,
)
}
// tracks peeks we're performing on another server over federation // tracks peeks we're performing on another server over federation
type OutboundPeek struct { type OutboundPeek struct {
PeekID string PeekID string

52
go.mod
View file

@ -2,48 +2,46 @@ module github.com/matrix-org/dendrite
require ( require (
github.com/DATA-DOG/go-sqlmock v1.5.0 github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/Shopify/sarama v1.27.0 github.com/HdrHistogram/hdrhistogram-go v1.0.1 // indirect
github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect github.com/Shopify/sarama v1.28.0
github.com/gologme/log v1.2.0 github.com/gologme/log v1.2.0
github.com/gorilla/mux v1.8.0 github.com/gorilla/mux v1.8.0
github.com/hashicorp/golang-lru v0.5.4 github.com/hashicorp/golang-lru v0.5.4
github.com/lib/pq v1.8.0 github.com/lib/pq v1.9.0
github.com/libp2p/go-libp2p v0.11.0 github.com/libp2p/go-libp2p v0.13.0
github.com/libp2p/go-libp2p-circuit v0.3.1 github.com/libp2p/go-libp2p-circuit v0.4.0
github.com/libp2p/go-libp2p-core v0.6.1 github.com/libp2p/go-libp2p-core v0.8.3
github.com/libp2p/go-libp2p-gostream v0.2.1 github.com/libp2p/go-libp2p-gostream v0.3.1
github.com/libp2p/go-libp2p-http v0.1.5 github.com/libp2p/go-libp2p-http v0.2.0
github.com/libp2p/go-libp2p-kad-dht v0.9.0 github.com/libp2p/go-libp2p-kad-dht v0.11.1
github.com/libp2p/go-libp2p-pubsub v0.3.5 github.com/libp2p/go-libp2p-pubsub v0.4.1
github.com/libp2p/go-libp2p-record v0.1.3 github.com/libp2p/go-libp2p-record v0.1.3
github.com/libp2p/go-yamux v1.3.9 // indirect
github.com/lucas-clemente/quic-go v0.17.3 github.com/lucas-clemente/quic-go v0.17.3
github.com/matrix-org/dugong v0.0.0-20180820122854-51a565b5666b github.com/matrix-org/dugong v0.0.0-20180820122854-51a565b5666b
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd
github.com/matrix-org/gomatrixserverlib v0.0.0-20210129163316-dd4d53729ead github.com/matrix-org/gomatrixserverlib v0.0.0-20210302161955-6142fe3f8c2c
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 github.com/matrix-org/naffka v0.0.0-20201009174903-d26a3b9cb161
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.2 github.com/mattn/go-sqlite3 v1.14.6
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6 github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6
github.com/opentracing/opentracing-go v1.2.0 github.com/opentracing/opentracing-go v1.2.0
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/pressly/goose v2.7.0-rc5+incompatible github.com/pressly/goose v2.7.0+incompatible
github.com/prometheus/client_golang v1.7.1 github.com/prometheus/client_golang v1.9.0
github.com/sirupsen/logrus v1.7.0 github.com/sirupsen/logrus v1.8.0
github.com/tidwall/gjson v1.6.7 github.com/tidwall/gjson v1.6.8
github.com/tidwall/sjson v1.1.4 github.com/tidwall/sjson v1.1.5
github.com/uber/jaeger-client-go v2.25.0+incompatible github.com/uber/jaeger-client-go v2.25.0+incompatible
github.com/uber/jaeger-lib v2.2.0+incompatible github.com/uber/jaeger-lib v2.4.0+incompatible
github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20201006093556-760d9a7fd5ee github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20210218094457-e77ca8019daa
go.uber.org/atomic v1.6.0 go.uber.org/atomic v1.7.0
golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83
golang.org/x/net v0.0.0-20200528225125-3c3fba18258b golang.org/x/net v0.0.0-20210226172049-e18ecbb05110
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4 // indirect gopkg.in/h2non/bimg.v1 v1.1.5
gopkg.in/h2non/bimg.v1 v1.1.4 gopkg.in/yaml.v2 v2.4.0
gopkg.in/yaml.v2 v2.3.0
) )
go 1.13 go 1.13

585
go.sum

File diff suppressed because it is too large Load diff

View file

@ -2,6 +2,7 @@ package caching
import ( import (
"fmt" "fmt"
"time"
lru "github.com/hashicorp/golang-lru" lru "github.com/hashicorp/golang-lru"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
@ -72,6 +73,11 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
go cacheCleaner(
roomVersions, serverKeys, roomServerStateKeyNIDs,
roomServerEventTypeNIDs, roomServerRoomIDs,
roomInfos, federationEvents,
)
return &Caches{ return &Caches{
RoomVersions: roomVersions, RoomVersions: roomVersions,
ServerKeys: serverKeys, ServerKeys: serverKeys,
@ -83,6 +89,20 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) {
}, nil }, nil
} }
func cacheCleaner(caches ...*InMemoryLRUCachePartition) {
for {
time.Sleep(time.Minute)
for _, cache := range caches {
// Hold onto the last 10% of the cache entries, since
// otherwise a quiet period might cause us to evict all
// cache entries entirely.
if cache.lru.Len() > cache.maxEntries/10 {
cache.lru.RemoveOldest()
}
}
}
}
type InMemoryLRUCachePartition struct { type InMemoryLRUCachePartition struct {
name string name string
mutable bool mutable bool

View file

@ -17,7 +17,7 @@ var build string
const ( const (
VersionMajor = 0 VersionMajor = 0
VersionMinor = 3 VersionMinor = 3
VersionPatch = 8 VersionPatch = 11
VersionTag = "" // example: "rc1" VersionTag = "" // example: "rc1"
) )

View file

@ -108,6 +108,8 @@ type OneTimeKeysCount struct {
// PerformUploadKeysRequest is the request to PerformUploadKeys // PerformUploadKeysRequest is the request to PerformUploadKeys
type PerformUploadKeysRequest struct { type PerformUploadKeysRequest struct {
UserID string // Required - User performing the request
DeviceID string // Optional - Device performing the request, for fetching OTK count
DeviceKeys []DeviceKeys DeviceKeys []DeviceKeys
OneTimeKeys []OneTimeKeys OneTimeKeys []OneTimeKeys
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update

View file

@ -513,6 +513,23 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
} }
func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
if req.UserID == "" {
res.Error = &api.KeyError{
Err: "user ID missing",
}
}
if req.DeviceID != "" && len(req.OneTimeKeys) == 0 {
counts, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.OneTimeKeysCount: %s", err),
}
}
if counts != nil {
res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
}
return
}
for _, key := range req.OneTimeKeys { for _, key := range req.OneTimeKeys {
// grab existing keys based on (user/device/algorithm/key ID) // grab existing keys based on (user/device/algorithm/key ID)
keyIDsWithAlgorithms := make([]string, len(key.KeyJSON)) keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
@ -521,9 +538,9 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
keyIDsWithAlgorithms[i] = keyIDWithAlgo keyIDsWithAlgorithms[i] = keyIDWithAlgo
i++ i++
} }
existingKeys, err := a.DB.ExistingOneTimeKeys(ctx, key.UserID, key.DeviceID, keyIDsWithAlgorithms) existingKeys, err := a.DB.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms)
if err != nil { if err != nil {
res.KeyError(key.UserID, key.DeviceID, &api.KeyError{ res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: "failed to query existing one-time keys: " + err.Error(), Err: "failed to query existing one-time keys: " + err.Error(),
}) })
continue continue
@ -531,8 +548,8 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
for keyIDWithAlgo := range existingKeys { for keyIDWithAlgo := range existingKeys {
// if keys exist and the JSON doesn't match, error out as the key already exists // if keys exist and the JSON doesn't match, error out as the key already exists
if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) { if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) {
res.KeyError(key.UserID, key.DeviceID, &api.KeyError{ res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time key already exists", key.UserID, key.DeviceID, keyIDWithAlgo), Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time key already exists", req.UserID, req.DeviceID, keyIDWithAlgo),
}) })
continue continue
} }
@ -540,8 +557,8 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
// store one-time keys // store one-time keys
counts, err := a.DB.StoreOneTimeKeys(ctx, key) counts, err := a.DB.StoreOneTimeKeys(ctx, key)
if err != nil { if err != nil {
res.KeyError(key.UserID, key.DeviceID, &api.KeyError{ res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", key.UserID, key.DeviceID, err.Error()), Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", req.UserID, req.DeviceID, err.Error()),
}) })
continue continue
} }

View file

@ -109,7 +109,7 @@ func RemoveDir(dir types.Path, logger *log.Entry) {
// WriteTempFile writes to a new temporary file. // WriteTempFile writes to a new temporary file.
// The file is deleted if there was an error while writing. // The file is deleted if there was an error while writing.
func WriteTempFile( func WriteTempFile(
ctx context.Context, reqReader io.Reader, maxFileSizeBytes config.FileSizeBytes, absBasePath config.Path, ctx context.Context, reqReader io.Reader, absBasePath config.Path,
) (hash types.Base64Hash, size types.FileSizeBytes, path types.Path, err error) { ) (hash types.Base64Hash, size types.FileSizeBytes, path types.Path, err error) {
size = -1 size = -1
logger := util.GetLogger(ctx) logger := util.GetLogger(ctx)
@ -124,18 +124,11 @@ func WriteTempFile(
} }
}() }()
// If the max_file_size_bytes configuration option is set to a positive
// number then limit the upload to that size. Otherwise, just read the
// whole file.
limitedReader := reqReader
if maxFileSizeBytes > 0 {
limitedReader = io.LimitReader(reqReader, int64(maxFileSizeBytes))
}
// Hash the file data. The hash will be returned. The hash is useful as a // Hash the file data. The hash will be returned. The hash is useful as a
// method of deduplicating files to save storage, as well as a way to conduct // method of deduplicating files to save storage, as well as a way to conduct
// integrity checks on the file data in the repository. // integrity checks on the file data in the repository.
hasher := sha256.New() hasher := sha256.New()
teeReader := io.TeeReader(limitedReader, hasher) teeReader := io.TeeReader(reqReader, hasher)
bytesWritten, err := io.Copy(tmpFileWriter, teeReader) bytesWritten, err := io.Copy(tmpFileWriter, teeReader)
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
RemoveDir(tmpDir, logger) RemoveDir(tmpDir, logger)

View file

@ -19,6 +19,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"mime" "mime"
"net/http" "net/http"
"net/url" "net/url"
@ -214,7 +215,7 @@ func (r *downloadRequest) doDownload(
ctx, r.MediaMetadata.MediaID, r.MediaMetadata.Origin, ctx, r.MediaMetadata.MediaID, r.MediaMetadata.Origin,
) )
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error querying the database") return nil, fmt.Errorf("db.GetMediaMetadata: %w", err)
} }
if mediaMetadata == nil { if mediaMetadata == nil {
if r.MediaMetadata.Origin == cfg.Matrix.ServerName { if r.MediaMetadata.Origin == cfg.Matrix.ServerName {
@ -253,16 +254,16 @@ func (r *downloadRequest) respondFromLocalFile(
) (*types.MediaMetadata, error) { ) (*types.MediaMetadata, error) {
filePath, err := fileutils.GetPathFromBase64Hash(r.MediaMetadata.Base64Hash, absBasePath) filePath, err := fileutils.GetPathFromBase64Hash(r.MediaMetadata.Base64Hash, absBasePath)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to get file path from metadata") return nil, fmt.Errorf("fileutils.GetPathFromBase64Hash: %w", err)
} }
file, err := os.Open(filePath) file, err := os.Open(filePath)
defer file.Close() // nolint: errcheck, staticcheck, megacheck defer file.Close() // nolint: errcheck, staticcheck, megacheck
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to open file") return nil, fmt.Errorf("os.Open: %w", err)
} }
stat, err := file.Stat() stat, err := file.Stat()
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to stat file") return nil, fmt.Errorf("file.Stat: %w", err)
} }
if r.MediaMetadata.FileSizeBytes > 0 && int64(r.MediaMetadata.FileSizeBytes) != stat.Size() { if r.MediaMetadata.FileSizeBytes > 0 && int64(r.MediaMetadata.FileSizeBytes) != stat.Size() {
@ -324,7 +325,7 @@ func (r *downloadRequest) respondFromLocalFile(
w.Header().Set("Content-Security-Policy", contentSecurityPolicy) w.Header().Set("Content-Security-Policy", contentSecurityPolicy)
if _, err := io.Copy(w, responseFile); err != nil { if _, err := io.Copy(w, responseFile); err != nil {
return nil, errors.Wrap(err, "failed to copy from cache") return nil, fmt.Errorf("io.Copy: %w", err)
} }
return responseMetadata, nil return responseMetadata, nil
} }
@ -421,7 +422,7 @@ func (r *downloadRequest) getThumbnailFile(
ctx, r.MediaMetadata.MediaID, r.MediaMetadata.Origin, ctx, r.MediaMetadata.MediaID, r.MediaMetadata.Origin,
) )
if err != nil { if err != nil {
return nil, nil, errors.Wrap(err, "error looking up thumbnails") return nil, nil, fmt.Errorf("db.GetThumbnails: %w", err)
} }
// If we get a thumbnailSize, a pre-generated thumbnail would be best but it is not yet generated. // If we get a thumbnailSize, a pre-generated thumbnail would be best but it is not yet generated.
@ -459,12 +460,12 @@ func (r *downloadRequest) getThumbnailFile(
thumbFile, err := os.Open(string(thumbPath)) thumbFile, err := os.Open(string(thumbPath))
if err != nil { if err != nil {
thumbFile.Close() // nolint: errcheck thumbFile.Close() // nolint: errcheck
return nil, nil, errors.Wrap(err, "failed to open file") return nil, nil, fmt.Errorf("os.Open: %w", err)
} }
thumbStat, err := thumbFile.Stat() thumbStat, err := thumbFile.Stat()
if err != nil { if err != nil {
thumbFile.Close() // nolint: errcheck thumbFile.Close() // nolint: errcheck
return nil, nil, errors.Wrap(err, "failed to stat file") return nil, nil, fmt.Errorf("thumbFile.Stat: %w", err)
} }
if types.FileSizeBytes(thumbStat.Size()) != thumbnail.MediaMetadata.FileSizeBytes { if types.FileSizeBytes(thumbStat.Size()) != thumbnail.MediaMetadata.FileSizeBytes {
thumbFile.Close() // nolint: errcheck thumbFile.Close() // nolint: errcheck
@ -491,7 +492,7 @@ func (r *downloadRequest) generateThumbnail(
activeThumbnailGeneration, maxThumbnailGenerators, db, r.Logger, activeThumbnailGeneration, maxThumbnailGenerators, db, r.Logger,
) )
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error creating thumbnail") return nil, fmt.Errorf("thumbnailer.GenerateThumbnail: %w", err)
} }
if busy { if busy {
return nil, nil return nil, nil
@ -502,7 +503,7 @@ func (r *downloadRequest) generateThumbnail(
thumbnailSize.Width, thumbnailSize.Height, thumbnailSize.ResizeMethod, thumbnailSize.Width, thumbnailSize.Height, thumbnailSize.ResizeMethod,
) )
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error looking up thumbnail") return nil, fmt.Errorf("db.GetThumbnail: %w", err)
} }
return thumbnail, nil return thumbnail, nil
} }
@ -543,7 +544,7 @@ func (r *downloadRequest) getRemoteFile(
ctx, r.MediaMetadata.MediaID, r.MediaMetadata.Origin, ctx, r.MediaMetadata.MediaID, r.MediaMetadata.Origin,
) )
if err != nil { if err != nil {
return errors.Wrap(err, "error querying the database.") return fmt.Errorf("db.GetMediaMetadata: %w", err)
} }
if mediaMetadata == nil { if mediaMetadata == nil {
@ -555,7 +556,7 @@ func (r *downloadRequest) getRemoteFile(
cfg.MaxThumbnailGenerators, cfg.MaxThumbnailGenerators,
) )
if err != nil { if err != nil {
return errors.Wrap(err, "error querying the database.") return fmt.Errorf("r.fetchRemoteFileAndStoreMetadata: %w", err)
} }
} else { } else {
// If we have a record, we can respond from the local file // If we have a record, we can respond from the local file
@ -673,6 +674,43 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
return nil return nil
} }
func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, body *io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) {
reader := *body
var contentLength int64
if contentLengthHeader != "" {
// A Content-Length header is provided. Let's try to parse it.
parsedLength, parseErr := strconv.ParseInt(contentLengthHeader, 10, 64)
if parseErr != nil {
r.Logger.WithError(parseErr).Warn("Failed to parse content length")
return 0, nil, fmt.Errorf("strconv.ParseInt: %w", parseErr)
}
if parsedLength > int64(maxFileSizeBytes) {
return 0, nil, fmt.Errorf(
"remote file size (%d bytes) exceeds locally configured max media size (%d bytes)",
parsedLength, maxFileSizeBytes,
)
}
// We successfully parsed the Content-Length, so we'll return a limited
// reader that restricts us to reading only up to this size.
reader = ioutil.NopCloser(io.LimitReader(*body, parsedLength))
contentLength = parsedLength
} else {
// Content-Length header is missing. If we have a maximum file size
// configured then we'll just make sure that the reader is limited to
// that size. We'll return a zero content length, but that's OK, since
// ultimately it will get rewritten later when the temp file is written
// to disk.
if maxFileSizeBytes > 0 {
reader = ioutil.NopCloser(io.LimitReader(*body, int64(maxFileSizeBytes)))
}
contentLength = 0
}
return contentLength, reader, nil
}
func (r *downloadRequest) fetchRemoteFile( func (r *downloadRequest) fetchRemoteFile(
ctx context.Context, ctx context.Context,
client *gomatrixserverlib.Client, client *gomatrixserverlib.Client,
@ -692,16 +730,18 @@ func (r *downloadRequest) fetchRemoteFile(
} }
defer resp.Body.Close() // nolint: errcheck defer resp.Body.Close() // nolint: errcheck
// get metadata from request and set metadata on response // The reader returned here will be limited either by the Content-Length
contentLength, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) // and/or the configured maximum media size.
if err != nil { contentLength, reader, parseErr := r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), &resp.Body, maxFileSizeBytes)
r.Logger.WithError(err).Warn("Failed to parse content length") if parseErr != nil {
return "", false, errors.Wrap(err, "invalid response from remote server") return "", false, parseErr
} }
if contentLength > int64(maxFileSizeBytes) { if contentLength > int64(maxFileSizeBytes) {
// TODO: Bubble up this as a 413 // TODO: Bubble up this as a 413
return "", false, fmt.Errorf("remote file is too large (%v > %v bytes)", contentLength, maxFileSizeBytes) return "", false, fmt.Errorf("remote file is too large (%v > %v bytes)", contentLength, maxFileSizeBytes)
} }
r.MediaMetadata.FileSizeBytes = types.FileSizeBytes(contentLength) r.MediaMetadata.FileSizeBytes = types.FileSizeBytes(contentLength)
r.MediaMetadata.ContentType = types.ContentType(resp.Header.Get("Content-Type")) r.MediaMetadata.ContentType = types.ContentType(resp.Header.Get("Content-Type"))
@ -728,7 +768,7 @@ func (r *downloadRequest) fetchRemoteFile(
// method of deduplicating files to save storage, as well as a way to conduct // method of deduplicating files to save storage, as well as a way to conduct
// integrity checks on the file data in the repository. // integrity checks on the file data in the repository.
// Data is truncated to maxFileSizeBytes. Content-Length was reported as 0 < Content-Length <= maxFileSizeBytes so this is OK. // Data is truncated to maxFileSizeBytes. Content-Length was reported as 0 < Content-Length <= maxFileSizeBytes so this is OK.
hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, resp.Body, maxFileSizeBytes, absBasePath) hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reader, absBasePath)
if err != nil { if err != nil {
r.Logger.WithError(err).WithFields(log.Fields{ r.Logger.WithError(err).WithFields(log.Fields{
"MaxFileSizeBytes": maxFileSizeBytes, "MaxFileSizeBytes": maxFileSizeBytes,
@ -747,7 +787,7 @@ func (r *downloadRequest) fetchRemoteFile(
// The database is the source of truth so we need to have moved the file first // The database is the source of truth so we need to have moved the file first
finalPath, duplicate, err := fileutils.MoveFileWithHashCheck(tmpDir, r.MediaMetadata, absBasePath, r.Logger) finalPath, duplicate, err := fileutils.MoveFileWithHashCheck(tmpDir, r.MediaMetadata, absBasePath, r.Logger)
if err != nil { if err != nil {
return "", false, errors.Wrap(err, "failed to move file") return "", false, fmt.Errorf("fileutils.MoveFileWithHashCheck: %w", err)
} }
if duplicate { if duplicate {
r.Logger.WithField("dst", finalPath).Info("File was stored previously - discarding duplicate") r.Logger.WithField("dst", finalPath).Info("File was stored previously - discarding duplicate")

View file

@ -147,7 +147,7 @@ func (r *uploadRequest) doUpload(
// r.storeFileAndMetadata(ctx, tmpDir, ...) // r.storeFileAndMetadata(ctx, tmpDir, ...)
// before you return from doUpload else we will leak a temp file. We could make this nicer with a `WithTransaction` style of // before you return from doUpload else we will leak a temp file. We could make this nicer with a `WithTransaction` style of
// nested function to guarantee either storage or cleanup. // nested function to guarantee either storage or cleanup.
hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reqReader, *cfg.MaxFileSizeBytes, cfg.AbsBasePath) hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reqReader, cfg.AbsBasePath)
if err != nil { if err != nil {
r.Logger.WithError(err).WithFields(log.Fields{ r.Logger.WithError(err).WithFields(log.Fields{
"MaxFileSizeBytes": *cfg.MaxFileSizeBytes, "MaxFileSizeBytes": *cfg.MaxFileSizeBytes,

View file

@ -172,7 +172,6 @@ func IsServerBannedFromRoom(ctx context.Context, rsAPI RoomserverInternalAPI, ro
// PopulatePublicRooms extracts PublicRoom information for all the provided room IDs. The IDs are not checked to see if they are visible in the // PopulatePublicRooms extracts PublicRoom information for all the provided room IDs. The IDs are not checked to see if they are visible in the
// published room directory. // published room directory.
// due to lots of switches // due to lots of switches
// nolint:gocyclo
func PopulatePublicRooms(ctx context.Context, roomIDs []string, rsAPI RoomserverInternalAPI) ([]gomatrixserverlib.PublicRoom, error) { func PopulatePublicRooms(ctx context.Context, roomIDs []string, rsAPI RoomserverInternalAPI) ([]gomatrixserverlib.PublicRoom, error) {
avatarTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.avatar", StateKey: ""} avatarTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.avatar", StateKey: ""}
nameTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.name", StateKey: ""} nameTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.name", StateKey: ""}

View file

@ -270,7 +270,6 @@ func CheckServerAllowedToSeeEvent(
} }
// TODO: Remove this when we have tests to assert correctness of this function // TODO: Remove this when we have tests to assert correctness of this function
// nolint:gocyclo
func ScanEventTree( func ScanEventTree(
ctx context.Context, db storage.Database, info types.RoomInfo, front []string, visited map[string]bool, limit int, ctx context.Context, db storage.Database, info types.RoomInfo, front []string, visited map[string]bool, limit int,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,

View file

@ -107,11 +107,21 @@ func (r *Inputer) updateMembership(
return updates, nil return updates, nil
} }
// In an ideal world, we shouldn't ever have "add" be nil and "remove" be
// set, as this implies that we're deleting a state event without replacing
// it (a thing that ordinarily shouldn't happen in Matrix). However, state
// resets are sadly a thing occasionally and we have to account for that.
// Beforehand there used to be a check here which stopped dead if we hit
// this scenario, but that meant that the membership table got out of sync
// after a state reset, often thinking that the user was still joined to
// the room even though the room state said otherwise, and this would prevent
// the user from being able to attempt to rejoin the room without modifying
// the database. So instead what we'll do is we'll just update the membership
// table to say that the user is "leave" and we'll use the old event to
// avoid nil pointer exceptions on the code path that follows.
if add == nil { if add == nil {
// This can happen when we have rejoined a room and suddenly we have a add = remove
// divergence between the former state and the new one. We don't want to newMembership = gomatrixserverlib.Leave
// act on removals and apparently there are no adds, so stop here.
return updates, nil
} }
mu, err := updater.MembershipUpdater(targetUserNID, r.isLocalTarget(add)) mu, err := updater.MembershipUpdater(targetUserNID, r.isLocalTarget(add))

View file

@ -381,7 +381,6 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr
// It returns a list of servers which can be queried for backfill requests. These servers // It returns a list of servers which can be queried for backfill requests. These servers
// will be servers that are in the room already. The entries at the beginning are preferred servers // will be servers that are in the room already. The entries at the beginning are preferred servers
// and will be tried first. An empty list will fail the request. // and will be tried first. An empty list will fail the request.
// nolint:gocyclo
func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID string) []gomatrixserverlib.ServerName { func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID string) []gomatrixserverlib.ServerName {
// eventID will be a prev_event ID of a backwards extremity, meaning we will not have a database entry for it. Instead, use // eventID will be a prev_event ID of a backwards extremity, meaning we will not have a database entry for it. Instead, use
// its successor, so look it up. // its successor, so look it up.

View file

@ -37,7 +37,6 @@ type Inviter struct {
Inputer *input.Inputer Inputer *input.Inputer
} }
// nolint:gocyclo
func (r *Inviter) PerformInvite( func (r *Inviter) PerformInvite(
ctx context.Context, ctx context.Context,
req *api.PerformInviteRequest, req *api.PerformInviteRequest,

View file

@ -147,7 +147,6 @@ func (r *Joiner) performJoinRoomByAlias(
} }
// TODO: Break this function up a bit // TODO: Break this function up a bit
// nolint:gocyclo
func (r *Joiner) performJoinRoomByID( func (r *Joiner) performJoinRoomByID(
ctx context.Context, ctx context.Context,
req *api.PerformJoinRequest, req *api.PerformJoinRequest,

View file

@ -49,7 +49,6 @@ func (r *Queryer) QueryLatestEventsAndState(
} }
// QueryStateAfterEvents implements api.RoomserverInternalAPI // QueryStateAfterEvents implements api.RoomserverInternalAPI
// nolint:gocyclo
func (r *Queryer) QueryStateAfterEvents( func (r *Queryer) QueryStateAfterEvents(
ctx context.Context, ctx context.Context,
request *api.QueryStateAfterEventsRequest, request *api.QueryStateAfterEventsRequest,
@ -112,7 +111,7 @@ func (r *Queryer) QueryStateAfterEvents(
return fmt.Errorf("getAuthChain: %w", err) return fmt.Errorf("getAuthChain: %w", err)
} }
stateEvents, err = state.ResolveConflictsAdhoc(info.RoomVersion, stateEvents, authEvents) stateEvents, err = gomatrixserverlib.ResolveConflicts(info.RoomVersion, stateEvents, authEvents)
if err != nil { if err != nil {
return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err) return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err)
} }
@ -372,7 +371,6 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
} }
// QueryMissingEvents implements api.RoomserverInternalAPI // QueryMissingEvents implements api.RoomserverInternalAPI
// nolint:gocyclo
func (r *Queryer) QueryMissingEvents( func (r *Queryer) QueryMissingEvents(
ctx context.Context, ctx context.Context,
request *api.QueryMissingEventsRequest, request *api.QueryMissingEventsRequest,
@ -469,7 +467,7 @@ func (r *Queryer) QueryStateAndAuthChain(
} }
if request.ResolveState { if request.ResolveState {
if stateEvents, err = state.ResolveConflictsAdhoc( if stateEvents, err = gomatrixserverlib.ResolveConflicts(
info.RoomVersion, stateEvents, authEvents, info.RoomVersion, stateEvents, authEvents,
); err != nil { ); err != nil {
return err return err

View file

@ -683,79 +683,6 @@ func (v *StateResolution) calculateStateAfterManyEvents(
return return
} }
// ResolveConflictsAdhoc is a helper function to assist the query API in
// performing state resolution when requested. This is a different code
// path to the rest of state.go because this assumes you already have
// gomatrixserverlib.Event objects and not just a bunch of NIDs like
// elsewhere in the state resolution.
// TODO: Some of this can possibly be deduplicated
func ResolveConflictsAdhoc(
version gomatrixserverlib.RoomVersion,
events []*gomatrixserverlib.Event,
authEvents []*gomatrixserverlib.Event,
) ([]*gomatrixserverlib.Event, error) {
type stateKeyTuple struct {
Type string
StateKey string
}
// Prepare our data structures.
eventMap := make(map[stateKeyTuple][]*gomatrixserverlib.Event)
var conflicted, notConflicted, resolved []*gomatrixserverlib.Event
// Run through all of the events that we were given and sort them
// into a map, sorted by (event_type, state_key) tuple. This means
// that we can easily spot events that are "conflicted", e.g.
// there are duplicate values for the same tuple key.
for _, event := range events {
if event.StateKey() == nil {
// Ignore events that are not state events.
continue
}
// Append the events if there is already a conflicted list for
// this tuple key, create it if not.
tuple := stateKeyTuple{event.Type(), *event.StateKey()}
eventMap[tuple] = append(eventMap[tuple], event)
}
// Split out the events in the map into conflicted and unconflicted
// buckets. The conflicted events will be ran through state res,
// whereas unconfliced events will always going to appear in the
// final resolved state.
for _, list := range eventMap {
if len(list) > 1 {
conflicted = append(conflicted, list...)
} else {
notConflicted = append(notConflicted, list...)
}
}
// Work out which state resolution algorithm we want to run for
// the room version.
stateResAlgo, err := version.StateResAlgorithm()
if err != nil {
return nil, err
}
switch stateResAlgo {
case gomatrixserverlib.StateResV1:
// Currently state res v1 doesn't handle unconflicted events
// for us, like state res v2 does, so we will need to add the
// unconflicted events into the state ourselves.
// TODO: Fix state res v1 so this is handled for the caller.
resolved = gomatrixserverlib.ResolveStateConflicts(conflicted, authEvents)
resolved = append(resolved, notConflicted...)
case gomatrixserverlib.StateResV2:
// TODO: auth difference here?
resolved = gomatrixserverlib.ResolveStateConflictsV2(conflicted, notConflicted, authEvents, authEvents)
default:
return nil, fmt.Errorf("unsupported state resolution algorithm %v", stateResAlgo)
}
// Return the final resolved state events, including both the
// resolved set of conflicted events, and the unconflicted events.
return resolved, nil
}
func (v *StateResolution) resolveConflicts( func (v *StateResolution) resolveConflicts(
ctx context.Context, version gomatrixserverlib.RoomVersion, ctx context.Context, version gomatrixserverlib.RoomVersion,
notConflicted, conflicted []types.StateEntry, notConflicted, conflicted []types.StateEntry,
@ -843,7 +770,6 @@ func (v *StateResolution) resolveConflictsV1(
// Returns a list that combines the entries without conflicts with the result of state resolution for the entries with conflicts. // Returns a list that combines the entries without conflicts with the result of state resolution for the entries with conflicts.
// The returned list is sorted by state key tuple. // The returned list is sorted by state key tuple.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
// nolint:gocyclo
func (v *StateResolution) resolveConflictsV2( func (v *StateResolution) resolveConflictsV2(
ctx context.Context, ctx context.Context,
notConflicted, conflicted []types.StateEntry, notConflicted, conflicted []types.StateEntry,

View file

@ -412,7 +412,6 @@ func (d *Database) GetLatestEventsForUpdate(
return updater, err return updater, err
} }
// nolint:gocyclo
func (d *Database) StoreEvent( func (d *Database) StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event, ctx context.Context, event *gomatrixserverlib.Event,
txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, isRejected bool, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, isRejected bool,
@ -672,7 +671,6 @@ func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) (
// to cross-reference with other tables when loading. // to cross-reference with other tables when loading.
// //
// Returns the redaction event and the event ID of the redacted event if this call resulted in a redaction. // Returns the redaction event and the event ID of the redacted event if this call resulted in a redaction.
// nolint:gocyclo
func (d *Database) handleRedactions( func (d *Database) handleRedactions(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, event *gomatrixserverlib.Event, ctx context.Context, txn *sql.Tx, eventNID types.EventNID, event *gomatrixserverlib.Event,
) (*gomatrixserverlib.Event, string, error) { ) (*gomatrixserverlib.Event, string, error) {
@ -802,7 +800,6 @@ func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event {
// GetStateEvent returns the current state event of a given type for a given room with a given state key // GetStateEvent returns the current state event of a given type for a given room with a given state key
// If no event could be found, returns nil // If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error // If there was an issue during the retrieval, returns an error
// nolint:gocyclo
func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) { func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) {
roomInfo, err := d.RoomInfo(ctx, roomID) roomInfo, err := d.RoomInfo(ctx, roomID)
if err != nil { if err != nil {
@ -893,7 +890,6 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
// nolint:gocyclo
func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) { func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) {
eventTypes := make([]string, 0, len(tuples)) eventTypes := make([]string, 0, len(tuples))
for _, tuple := range tuples { for _, tuple := range tuples {

View file

@ -316,7 +316,6 @@ func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationCli
// SetupAndServeHTTP sets up the HTTP server to serve endpoints registered on // SetupAndServeHTTP sets up the HTTP server to serve endpoints registered on
// ApiMux under /api/ and adds a prometheus handler under /metrics. // ApiMux under /api/ and adds a prometheus handler under /metrics.
// nolint:gocyclo
func (b *BaseDendrite) SetupAndServeHTTP( func (b *BaseDendrite) SetupAndServeHTTP(
internalHTTPAddr, externalHTTPAddr config.HTTPAddress, internalHTTPAddr, externalHTTPAddr config.HTTPAddress,
certFile, keyFile *string, certFile, keyFile *string,

View file

@ -238,7 +238,6 @@ func federatedEventRelationship(
} }
} }
// nolint:gocyclo
func (rc *reqCtx) process() (*gomatrixserverlib.MSC2836EventRelationshipsResponse, *util.JSONResponse) { func (rc *reqCtx) process() (*gomatrixserverlib.MSC2836EventRelationshipsResponse, *util.JSONResponse) {
var res gomatrixserverlib.MSC2836EventRelationshipsResponse var res gomatrixserverlib.MSC2836EventRelationshipsResponse
var returnEvents []*gomatrixserverlib.HeaderedEvent var returnEvents []*gomatrixserverlib.HeaderedEvent

View file

@ -46,7 +46,7 @@ const (
// Defaults sets the request defaults // Defaults sets the request defaults
func Defaults(r *gomatrixserverlib.MSC2946SpacesRequest) { func Defaults(r *gomatrixserverlib.MSC2946SpacesRequest) {
r.Limit = 100 r.Limit = 2000
r.MaxRoomsPerSpace = -1 r.MaxRoomsPerSpace = -1
} }
@ -70,11 +70,11 @@ func Enable(
} }
}) })
base.PublicClientAPIMux.Handle("/unstable/rooms/{roomID}/spaces", base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/spaces",
httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(db, rsAPI, fsAPI, base.Cfg.Global.ServerName)), httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(db, rsAPI, fsAPI, base.Cfg.Global.ServerName)),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
base.PublicFederationAPIMux.Handle("/unstable/spaces/{roomID}", httputil.MakeExternalAPI( base.PublicFederationAPIMux.Handle("/unstable/org.matrix.msc2946/spaces/{roomID}", httputil.MakeExternalAPI(
"msc2946_fed_spaces", func(req *http.Request) util.JSONResponse { "msc2946_fed_spaces", func(req *http.Request) util.JSONResponse {
fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest(
req, time.Now(), base.Cfg.Global.ServerName, keyRing, req, time.Now(), base.Cfg.Global.ServerName, keyRing,
@ -108,9 +108,6 @@ func federatedSpacesHandler(
JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()),
} }
} }
if r.Limit > 100 {
r.Limit = 100
}
w := walker{ w := walker{
req: &r, req: &r,
rootRoomID: roomID, rootRoomID: roomID,
@ -147,9 +144,6 @@ func spacesHandler(
if resErr := chttputil.UnmarshalJSONRequest(req, &r); resErr != nil { if resErr := chttputil.UnmarshalJSONRequest(req, &r); resErr != nil {
return *resErr return *resErr
} }
if r.Limit > 100 {
r.Limit = 100
}
w := walker{ w := walker{
req: &r, req: &r,
rootRoomID: roomID, rootRoomID: roomID,
@ -223,7 +217,6 @@ func (w *walker) markSent(id string) {
w.inMemoryBatchCache[w.callerID()] = m w.inMemoryBatchCache[w.callerID()] = m
} }
// nolint:gocyclo
func (w *walker) walk() *gomatrixserverlib.MSC2946SpacesResponse { func (w *walker) walk() *gomatrixserverlib.MSC2946SpacesResponse {
var res gomatrixserverlib.MSC2946SpacesResponse var res gomatrixserverlib.MSC2946SpacesResponse
// Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms // Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms

View file

@ -309,7 +309,7 @@ func postSpaces(t *testing.T, expectCode int, accessToken, roomID string, req *g
t.Fatalf("failed to marshal request: %s", err) t.Fatalf("failed to marshal request: %s", err)
} }
httpReq, err := http.NewRequest( httpReq, err := http.NewRequest(
"POST", "http://localhost:8010/_matrix/client/unstable/rooms/"+url.PathEscape(roomID)+"/spaces", "POST", "http://localhost:8010/_matrix/client/unstable/org.matrix.msc2946/rooms/"+url.PathEscape(roomID)+"/spaces",
bytes.NewBuffer(data), bytes.NewBuffer(data),
) )
httpReq.Header.Set("Authorization", "Bearer "+accessToken) httpReq.Header.Set("Authorization", "Bearer "+accessToken)

View file

@ -46,7 +46,6 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.KeyInternalAPI, userID,
// DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response // DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST // was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
// be already filled in with join/leave information. // be already filled in with join/leave information.
// nolint:gocyclo
func DeviceListCatchup( func DeviceListCatchup(
ctx context.Context, keyAPI keyapi.KeyInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI, ctx context.Context, keyAPI keyapi.KeyInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
userID string, res *types.Response, from, to types.LogPosition, userID string, res *types.Response, from, to types.LogPosition,
@ -137,7 +136,6 @@ func DeviceListCatchup(
} }
// TrackChangedUsers calculates the values of device_lists.changed|left in the /sync response. // TrackChangedUsers calculates the values of device_lists.changed|left in the /sync response.
// nolint:gocyclo
func TrackChangedUsers( func TrackChangedUsers(
ctx context.Context, rsAPI roomserverAPI.RoomserverInternalAPI, userID string, newlyJoinedRooms, newlyLeftRooms []string, ctx context.Context, rsAPI roomserverAPI.RoomserverInternalAPI, userID string, newlyJoinedRooms, newlyLeftRooms []string,
) (changed, left []string, err error) { ) (changed, left []string, err error) {

View file

@ -61,7 +61,6 @@ const defaultMessagesLimit = 10
// OnIncomingMessagesRequest implements the /messages endpoint from the // OnIncomingMessagesRequest implements the /messages endpoint from the
// client-server API. // client-server API.
// See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages // See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages
// nolint:gocyclo
func OnIncomingMessagesRequest( func OnIncomingMessagesRequest(
req *http.Request, db storage.Database, roomID string, device *userapi.Device, req *http.Request, db storage.Database, roomID string, device *userapi.Device,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
@ -306,7 +305,6 @@ func (r *messagesReq) retrieveEvents() (
return clientEvents, start, end, err return clientEvents, start, end, err
} }
// nolint:gocyclo
func (r *messagesReq) filterHistoryVisible(events []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent { func (r *messagesReq) filterHistoryVisible(events []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
// TODO FIXME: We don't fully implement history visibility yet. To avoid leaking events which the // TODO FIXME: We don't fully implement history visibility yet. To avoid leaking events which the
// user shouldn't see, we check the recent events and remove any prior to the join event of the user // user shouldn't see, we check the recent events and remove any prior to the join event of the user

View file

@ -35,7 +35,7 @@ type Database interface {
MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error)
CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) ([]*gomatrixserverlib.HeaderedEvent, error) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error)
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)

View file

@ -53,7 +53,7 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_account_data_id_idx ON syncapi_account
const insertAccountDataSQL = "" + const insertAccountDataSQL = "" +
"INSERT INTO syncapi_account_data_type (user_id, room_id, type) VALUES ($1, $2, $3)" + "INSERT INTO syncapi_account_data_type (user_id, room_id, type) VALUES ($1, $2, $3)" +
" ON CONFLICT ON CONSTRAINT syncapi_account_data_unique" + " ON CONFLICT ON CONSTRAINT syncapi_account_data_unique" +
" DO UPDATE SET id = EXCLUDED.id" + " DO UPDATE SET id = nextval('syncapi_stream_id')" +
" RETURNING id" " RETURNING id"
const selectAccountDataInRangeSQL = "" + const selectAccountDataInRangeSQL = "" +

View file

@ -84,7 +84,8 @@ const selectCurrentStateSQL = "" +
" AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" + " AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" + " AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" +
" AND ( $6::bool IS NULL OR contains_url = $6 )" + " AND ( $6::bool IS NULL OR contains_url = $6 )" +
" LIMIT $7" " AND (event_id = ANY($7)) IS NOT TRUE" +
" LIMIT $8"
const selectJoinedUsersSQL = "" + const selectJoinedUsersSQL = "" +
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'" "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'"
@ -197,6 +198,7 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
func (s *currentRoomStateStatements) SelectCurrentState( func (s *currentRoomStateStatements) SelectCurrentState(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
excludeEventIDs []string,
) ([]*gomatrixserverlib.HeaderedEvent, error) { ) ([]*gomatrixserverlib.HeaderedEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt) stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt)
rows, err := stmt.QueryContext(ctx, roomID, rows, err := stmt.QueryContext(ctx, roomID,
@ -205,6 +207,7 @@ func (s *currentRoomStateStatements) SelectCurrentState(
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)),
stateFilter.ContainsURL, stateFilter.ContainsURL,
pq.StringArray(excludeEventIDs),
stateFilter.Limit, stateFilter.Limit,
) )
if err != nil { if err != nil {

View file

@ -75,7 +75,7 @@ const insertEventSQL = "" +
"INSERT INTO syncapi_output_room_events (" + "INSERT INTO syncapi_output_room_events (" +
"room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" + "room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" +
") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) " + ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) " +
"ON CONFLICT ON CONSTRAINT syncapi_event_id_idx DO UPDATE SET exclude_from_sync = $11 " + "ON CONFLICT ON CONSTRAINT syncapi_event_id_idx DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $11) " +
"RETURNING id" "RETURNING id"
const selectEventsSQL = "" + const selectEventsSQL = "" +

View file

@ -36,7 +36,6 @@ type SyncServerDatasource struct {
} }
// NewDatabase creates a new sync server database // NewDatabase creates a new sync server database
// nolint:gocyclo
func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) { func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) {
var d SyncServerDatasource var d SyncServerDatasource
var err error var err error

View file

@ -103,8 +103,8 @@ func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.S
return types.StreamPosition(id), nil return types.StreamPosition(id), nil
} }
func (d *Database) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) ([]*gomatrixserverlib.HeaderedEvent, error) { func (d *Database) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
return d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilterPart) return d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilterPart, excludeEventIDs)
} }
func (d *Database) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) { func (d *Database) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) {
@ -195,7 +195,7 @@ func (d *Database) GetStateEvent(
func (d *Database) GetStateEventsForRoom( func (d *Database) GetStateEventsForRoom(
ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter, ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter,
) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) { ) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) {
stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilter) stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilter, nil)
return return
} }
@ -661,7 +661,6 @@ func (d *Database) fetchMissingStateEvents(
// exclusive of oldPos, inclusive of newPos, for the rooms in which // exclusive of oldPos, inclusive of newPos, for the rooms in which
// the user has new membership events. // the user has new membership events.
// A list of joined room IDs is also returned in case the caller needs it. // A list of joined room IDs is also returned in case the caller needs it.
// nolint:gocyclo
func (d *Database) GetStateDeltas( func (d *Database) GetStateDeltas(
ctx context.Context, device *userapi.Device, ctx context.Context, device *userapi.Device,
r types.Range, userID string, r types.Range, userID string,
@ -773,7 +772,6 @@ func (d *Database) GetStateDeltas(
// requests with full_state=true. // requests with full_state=true.
// Fetches full state for all joined rooms and uses selectStateInRange to get // Fetches full state for all joined rooms and uses selectStateInRange to get
// updates for other rooms. // updates for other rooms.
// nolint:gocyclo
func (d *Database) GetStateDeltasForFullStateSync( func (d *Database) GetStateDeltasForFullStateSync(
ctx context.Context, device *userapi.Device, ctx context.Context, device *userapi.Device,
r types.Range, userID string, r types.Range, userID string,
@ -870,7 +868,7 @@ func (d *Database) currentStateStreamEventsForRoom(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
allState, err := d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilter) allState, err := d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilter, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -39,7 +39,7 @@ CREATE TABLE IF NOT EXISTS syncapi_account_data_type (
const insertAccountDataSQL = "" + const insertAccountDataSQL = "" +
"INSERT INTO syncapi_account_data_type (id, user_id, room_id, type) VALUES ($1, $2, $3, $4)" + "INSERT INTO syncapi_account_data_type (id, user_id, room_id, type) VALUES ($1, $2, $3, $4)" +
" ON CONFLICT (user_id, room_id, type) DO UPDATE" + " ON CONFLICT (user_id, room_id, type) DO UPDATE" +
" SET id = EXCLUDED.id" " SET id = $5"
const selectAccountDataInRangeSQL = "" + const selectAccountDataInRangeSQL = "" +
"SELECT room_id, type FROM syncapi_account_data_type" + "SELECT room_id, type FROM syncapi_account_data_type" +
@ -86,7 +86,7 @@ func (s *accountDataStatements) InsertAccountData(
if err != nil { if err != nil {
return return
} }
_, err = sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType) _, err = sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType, pos)
return return
} }

View file

@ -178,6 +178,7 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
func (s *currentRoomStateStatements) SelectCurrentState( func (s *currentRoomStateStatements) SelectCurrentState(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
excludeEventIDs []string,
) ([]*gomatrixserverlib.HeaderedEvent, error) { ) ([]*gomatrixserverlib.HeaderedEvent, error) {
stmt, params, err := prepareWithFilters( stmt, params, err := prepareWithFilters(
s.db, txn, selectCurrentStateSQL, s.db, txn, selectCurrentStateSQL,
@ -186,7 +187,7 @@ func (s *currentRoomStateStatements) SelectCurrentState(
}, },
stateFilter.Senders, stateFilter.NotSenders, stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes, stateFilter.Types, stateFilter.NotTypes,
stateFilter.Limit, FilterOrderNone, excludeEventIDs, stateFilter.Limit, FilterOrderNone,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, fmt.Errorf("s.prepareWithFilters: %w", err)

View file

@ -25,7 +25,7 @@ const (
// parts. // parts.
func prepareWithFilters( func prepareWithFilters(
db *sql.DB, txn *sql.Tx, query string, params []interface{}, db *sql.DB, txn *sql.Tx, query string, params []interface{},
senders, notsenders, types, nottypes []string, senders, notsenders, types, nottypes []string, excludeEventIDs []string,
limit int, order FilterOrder, limit int, order FilterOrder,
) (*sql.Stmt, []interface{}, error) { ) (*sql.Stmt, []interface{}, error) {
offset := len(params) offset := len(params)
@ -53,6 +53,12 @@ func prepareWithFilters(
params, offset = append(params, v), offset+1 params, offset = append(params, v), offset+1
} }
} }
if count := len(excludeEventIDs); count > 0 {
query += " AND event_id NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range excludeEventIDs {
params, offset = append(params, v), offset+1
}
}
switch order { switch order {
case FilterOrderAsc: case FilterOrderAsc:
query += " ORDER BY id ASC" query += " ORDER BY id ASC"

View file

@ -54,7 +54,7 @@ const insertEventSQL = "" +
"INSERT INTO syncapi_output_room_events (" + "INSERT INTO syncapi_output_room_events (" +
"id, room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" + "id, room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" +
") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " + ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " +
"ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = $13" "ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)"
const selectEventsSQL = "" + const selectEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1" "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1"
@ -150,7 +150,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
}, },
stateFilter.Senders, stateFilter.NotSenders, stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes, stateFilter.Types, stateFilter.NotTypes,
stateFilter.Limit, FilterOrderAsc, nil, stateFilter.Limit, FilterOrderAsc,
) )
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err)
@ -326,7 +326,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
}, },
eventFilter.Senders, eventFilter.NotSenders, eventFilter.Senders, eventFilter.NotSenders,
eventFilter.Types, eventFilter.NotTypes, eventFilter.Types, eventFilter.NotTypes,
eventFilter.Limit+1, FilterOrderDesc, nil, eventFilter.Limit+1, FilterOrderDesc,
) )
if err != nil { if err != nil {
return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err)
@ -374,7 +374,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
}, },
eventFilter.Senders, eventFilter.NotSenders, eventFilter.Senders, eventFilter.NotSenders,
eventFilter.Types, eventFilter.NotTypes, eventFilter.Types, eventFilter.NotTypes,
eventFilter.Limit, FilterOrderAsc, nil, eventFilter.Limit, FilterOrderAsc,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err) return nil, fmt.Errorf("s.prepareWithFilters: %w", err)

View file

@ -52,7 +52,6 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
return &d, nil return &d, nil
} }
// nolint:gocyclo
func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) { func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) {
if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil { if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil {
return err return err

View file

@ -91,7 +91,7 @@ type CurrentRoomState interface {
DeleteRoomStateByEventID(ctx context.Context, txn *sql.Tx, eventID string) error DeleteRoomStateByEventID(ctx context.Context, txn *sql.Tx, eventID string) error
DeleteRoomStateForRoom(ctx context.Context, txn *sql.Tx, roomID string) error DeleteRoomStateForRoom(ctx context.Context, txn *sql.Tx, roomID string) error
// SelectCurrentState returns all the current state events for the given room. // SelectCurrentState returns all the current state events for the given room.
SelectCurrentState(ctx context.Context, txn *sql.Tx, roomID string, stateFilter *gomatrixserverlib.StateFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectCurrentState(ctx context.Context, txn *sql.Tx, roomID string, stateFilter *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error)
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error) SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error)
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs. // SelectJoinedUsers returns a map of room ID to a list of joined user IDs.

View file

@ -98,7 +98,7 @@ func (p *PDUStreamProvider) CompleteSync(
var jr *types.JoinResponse var jr *types.JoinResponse
jr, err = p.getJoinResponseForCompleteSync( jr, err = p.getJoinResponseForCompleteSync(
ctx, roomID, r, &stateFilter, &eventFilter, req.Device, ctx, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device,
) )
if err != nil { if err != nil {
req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed") req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed")
@ -124,7 +124,7 @@ func (p *PDUStreamProvider) CompleteSync(
if !peek.Deleted { if !peek.Deleted {
var jr *types.JoinResponse var jr *types.JoinResponse
jr, err = p.getJoinResponseForCompleteSync( jr, err = p.getJoinResponseForCompleteSync(
ctx, peek.RoomID, r, &stateFilter, &eventFilter, req.Device, ctx, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device,
) )
if err != nil { if err != nil {
req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed") req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed")
@ -137,7 +137,6 @@ func (p *PDUStreamProvider) CompleteSync(
return to return to
} }
// nolint:gocyclo
func (p *PDUStreamProvider) IncrementalSync( func (p *PDUStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
req *types.SyncRequest, req *types.SyncRequest,
@ -260,20 +259,30 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
r types.Range, r types.Range,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
eventFilter *gomatrixserverlib.RoomEventFilter, eventFilter *gomatrixserverlib.RoomEventFilter,
wantFullState bool,
device *userapi.Device, device *userapi.Device,
) (jr *types.JoinResponse, err error) { ) (jr *types.JoinResponse, err error) {
var stateEvents []*gomatrixserverlib.HeaderedEvent // TODO: When filters are added, we may need to call this multiple times to get enough events.
stateEvents, err = p.DB.CurrentState(ctx, roomID, stateFilter) // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
recentStreamEvents, limited, err := p.DB.RecentEvents(
ctx, roomID, r, eventFilter, true, true,
)
if err != nil { if err != nil {
return return
} }
// TODO: When filters are added, we may need to call this multiple times to get enough events.
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 // Get the event IDs of the stream events we fetched. There's no point in us
var recentStreamEvents []types.StreamEvent var excludingEventIDs []string
var limited bool if !wantFullState {
recentStreamEvents, limited, err = p.DB.RecentEvents( excludingEventIDs = make([]string, 0, len(recentStreamEvents))
ctx, roomID, r, eventFilter, true, true, for _, event := range recentStreamEvents {
) if event.StateKey() != nil {
excludingEventIDs = append(excludingEventIDs, event.EventID())
}
}
}
stateEvents, err := p.DB.CurrentState(ctx, roomID, stateFilter, excludingEventIDs)
if err != nil { if err != nil {
return return
} }

View file

@ -67,3 +67,6 @@ Forgotten room messages cannot be paginated
# Blacklisted due to flakiness # Blacklisted due to flakiness
Can re-join room if re-invited Can re-join room if re-invited
# Blacklisted due to flakiness after #1774
Local device key changes get to remote servers with correct prev_id

View file

@ -143,7 +143,6 @@ Local new device changes appear in v2 /sync
Local update device changes appear in v2 /sync Local update device changes appear in v2 /sync
Get left notifs for other users in sync and /keys/changes when user leaves Get left notifs for other users in sync and /keys/changes when user leaves
Local device key changes get to remote servers Local device key changes get to remote servers
Local device key changes get to remote servers with correct prev_id
Server correctly handles incoming m.device_list_update Server correctly handles incoming m.device_list_update
If remote user leaves room, changes device and rejoins we see update in sync If remote user leaves room, changes device and rejoins we see update in sync
If remote user leaves room, changes device and rejoins we see update in /keys/changes If remote user leaves room, changes device and rejoins we see update in /keys/changes

View file

@ -161,6 +161,7 @@ func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) er
var uploadRes keyapi.PerformUploadKeysResponse var uploadRes keyapi.PerformUploadKeysResponse
a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
UserID: userID,
DeviceKeys: deviceKeys, DeviceKeys: deviceKeys,
}, &uploadRes) }, &uploadRes)
if uploadRes.Error != nil { if uploadRes.Error != nil {
@ -217,6 +218,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
// display name has changed: update the device key // display name has changed: update the device key
var uploadRes keyapi.PerformUploadKeysResponse var uploadRes keyapi.PerformUploadKeysResponse
a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
UserID: req.RequestingUserID,
DeviceKeys: []keyapi.DeviceKeys{ DeviceKeys: []keyapi.DeviceKeys{
{ {
DeviceID: dev.ID, DeviceID: dev.ID,

View file

@ -170,8 +170,8 @@ func (d *Database) CreateAccount(
func (d *Database) createAccount( func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
) (*api.Account, error) { ) (*api.Account, error) {
var account *api.Account
var err error var err error
// Generate a password hash if this is not a password-less user // Generate a password hash if this is not a password-less user
hash := "" hash := ""
if plaintextPassword != "" { if plaintextPassword != "" {
@ -180,14 +180,16 @@ func (d *Database) createAccount(
return nil, err return nil, err
} }
} }
if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil { if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil {
if sqlutil.IsUniqueConstraintViolationErr(err) { if sqlutil.IsUniqueConstraintViolationErr(err) {
return nil, sqlutil.ErrUserExists return nil, sqlutil.ErrUserExists
} }
return nil, err return nil, err
} }
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ return nil, err
}
if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": { "global": {
"content": [], "content": [],
"override": [], "override": [],
@ -198,7 +200,7 @@ func (d *Database) createAccount(
}`)); err != nil { }`)); err != nil {
return nil, err return nil, err
} }
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) return account, nil
} }
// SaveAccountData saves new account data for a given user and a given room. // SaveAccountData saves new account data for a given user and a given room.

View file

@ -1,27 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build !wasm
package sqlite3
import (
"errors"
"github.com/mattn/go-sqlite3"
)
func isConstraintError(err error) bool {
return errors.Is(err, sqlite3.ErrConstraint)
}

View file

@ -204,6 +204,7 @@ func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
) (*api.Account, error) { ) (*api.Account, error) {
var err error var err error
var account *api.Account
// Generate a password hash if this is not a password-less user // Generate a password hash if this is not a password-less user
hash := "" hash := ""
if plaintextPassword != "" { if plaintextPassword != "" {
@ -212,14 +213,13 @@ func (d *Database) createAccount(
return nil, err return nil, err
} }
} }
if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil { if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil {
if isConstraintError(err) { return nil, sqlutil.ErrUserExists
return nil, sqlutil.ErrUserExists }
} if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
return nil, err return nil, err
} }
if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": { "global": {
"content": [], "content": [],
"override": [], "override": [],
@ -230,7 +230,7 @@ func (d *Database) createAccount(
}`)); err != nil { }`)); err != nil {
return nil, err return nil, err
} }
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) return account, nil
} }
// SaveAccountData saves new account data for a given user and a given room. // SaveAccountData saves new account data for a given user and a given room.