mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-16 18:43:10 -06:00
Merge branch 'main' into s7evink/devicelistcleanup
This commit is contained in:
commit
94a092acd1
90
.github/workflows/dendrite.yml
vendored
90
.github/workflows/dendrite.yml
vendored
|
|
@ -26,22 +26,14 @@ jobs:
|
|||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.18
|
||||
|
||||
- uses: actions/cache@v2
|
||||
with:
|
||||
path: |
|
||||
~/.cache/go-build
|
||||
~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-wasm-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-wasm
|
||||
cache: true
|
||||
|
||||
- name: Install Node
|
||||
uses: actions/setup-node@v2
|
||||
with:
|
||||
node-version: 14
|
||||
|
||||
- uses: actions/cache@v2
|
||||
- uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/.npm
|
||||
key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }}
|
||||
|
|
@ -109,19 +101,12 @@ jobs:
|
|||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
cache: true
|
||||
- name: Set up gotestfmt
|
||||
uses: gotesttools/gotestfmt-action@v2
|
||||
with:
|
||||
# Optional: pass GITHUB_TOKEN to avoid rate limiting.
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
- uses: actions/cache@v3
|
||||
with:
|
||||
path: |
|
||||
~/.cache/go-build
|
||||
~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go${{ matrix.go }}-test-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go${{ matrix.go }}-test-
|
||||
- run: go test -json -v ./... 2>&1 | gotestfmt
|
||||
env:
|
||||
POSTGRES_HOST: localhost
|
||||
|
|
@ -146,17 +131,17 @@ jobs:
|
|||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
- name: Install dependencies x86
|
||||
if: ${{ matrix.goarch == '386' }}
|
||||
run: sudo apt update && sudo apt-get install -y gcc-multilib
|
||||
- uses: actions/cache@v3
|
||||
with:
|
||||
path: |
|
||||
~/.cache/go-build
|
||||
~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }}
|
||||
key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goarch }}-
|
||||
key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-
|
||||
- name: Install dependencies x86
|
||||
if: ${{ matrix.goarch == '386' }}
|
||||
run: sudo apt update && sudo apt-get install -y gcc-multilib
|
||||
- env:
|
||||
GOOS: ${{ matrix.goos }}
|
||||
GOARCH: ${{ matrix.goarch }}
|
||||
|
|
@ -180,16 +165,16 @@ jobs:
|
|||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y gcc-mingw-w64-x86-64 # install required gcc
|
||||
- uses: actions/cache@v3
|
||||
with:
|
||||
path: |
|
||||
~/.cache/go-build
|
||||
~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goos }}-${{ hashFiles('**/go.sum') }}
|
||||
key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goos }}
|
||||
key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y gcc-mingw-w64-x86-64 # install required gcc
|
||||
- env:
|
||||
GOOS: ${{ matrix.goos }}
|
||||
GOARCH: ${{ matrix.goarch }}
|
||||
|
|
@ -221,18 +206,13 @@ jobs:
|
|||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: "1.18"
|
||||
- uses: actions/cache@v3
|
||||
with:
|
||||
path: |
|
||||
~/.cache/go-build
|
||||
~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-upgrade-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-upgrade
|
||||
cache: true
|
||||
- name: Build upgrade-tests
|
||||
run: go build ./cmd/dendrite-upgrade-tests
|
||||
- name: Test upgrade
|
||||
- name: Test upgrade (PostgreSQL)
|
||||
run: ./dendrite-upgrade-tests --head .
|
||||
- name: Test upgrade (SQLite)
|
||||
run: ./dendrite-upgrade-tests --sqlite --head .
|
||||
|
||||
# run database upgrade tests, skipping over one version
|
||||
upgrade_test_direct:
|
||||
|
|
@ -246,17 +226,12 @@ jobs:
|
|||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: "1.18"
|
||||
- uses: actions/cache@v3
|
||||
with:
|
||||
path: |
|
||||
~/.cache/go-build
|
||||
~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-upgrade-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-upgrade
|
||||
cache: true
|
||||
- name: Build upgrade-tests
|
||||
run: go build ./cmd/dendrite-upgrade-tests
|
||||
- name: Test upgrade
|
||||
- name: Test upgrade (PostgreSQL)
|
||||
run: ./dendrite-upgrade-tests -direct -from HEAD-2 --head .
|
||||
- name: Test upgrade (SQLite)
|
||||
run: ./dendrite-upgrade-tests -direct -from HEAD-2 --head .
|
||||
|
||||
# run Sytest in different variations
|
||||
|
|
@ -291,6 +266,8 @@ jobs:
|
|||
image: matrixdotorg/sytest-dendrite:latest
|
||||
volumes:
|
||||
- ${{ github.workspace }}:/src
|
||||
- /root/.cache/go-build:/github/home/.cache/go-build
|
||||
- /root/.cache/go-mod:/gopath/pkg/mod
|
||||
env:
|
||||
POSTGRES: ${{ matrix.postgres && 1}}
|
||||
API: ${{ matrix.api && 1 }}
|
||||
|
|
@ -298,6 +275,14 @@ jobs:
|
|||
CGO_ENABLED: ${{ matrix.cgo && 1 }}
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/cache@v3
|
||||
with:
|
||||
path: |
|
||||
~/.cache/go-build
|
||||
/gopath/pkg/mod
|
||||
key: ${{ runner.os }}-go-sytest-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-sytest-
|
||||
- name: Run Sytest
|
||||
run: /bootstrap.sh dendrite
|
||||
working-directory: /src
|
||||
|
|
@ -332,12 +317,14 @@ jobs:
|
|||
matrix:
|
||||
include:
|
||||
- label: SQLite native
|
||||
cgo: 0
|
||||
|
||||
- label: SQLite Cgo
|
||||
cgo: 1
|
||||
|
||||
- label: SQLite native, full HTTP APIs
|
||||
api: full-http
|
||||
cgo: 0
|
||||
|
||||
- label: SQLite Cgo, full HTTP APIs
|
||||
api: full-http
|
||||
|
|
@ -345,10 +332,12 @@ jobs:
|
|||
|
||||
- label: PostgreSQL
|
||||
postgres: Postgres
|
||||
cgo: 0
|
||||
|
||||
- label: PostgreSQL, full HTTP APIs
|
||||
postgres: Postgres
|
||||
api: full-http
|
||||
cgo: 0
|
||||
steps:
|
||||
# Env vars are set file a file given by $GITHUB_PATH. We need both Go 1.17 and GOPATH on env to run Complement.
|
||||
# See https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#adding-a-system-path
|
||||
|
|
@ -356,14 +345,12 @@ jobs:
|
|||
run: |
|
||||
echo "$GOROOT_1_17_X64/bin" >> $GITHUB_PATH
|
||||
echo "~/go/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: "Install Complement Dependencies"
|
||||
# We don't need to install Go because it is included on the Ubuntu 20.04 image:
|
||||
# See https://github.com/actions/virtual-environments/blob/main/images/linux/Ubuntu2004-Readme.md specifically GOROOT_1_17_X64
|
||||
run: |
|
||||
sudo apt-get update && sudo apt-get install -y libolm3 libolm-dev
|
||||
go get -v github.com/gotesttools/gotestfmt/v2/cmd/gotestfmt@latest
|
||||
|
||||
- name: Run actions/checkout@v3 for dendrite
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
|
|
@ -389,12 +376,10 @@ jobs:
|
|||
if [[ -z "$BRANCH_NAME" || $BRANCH_NAME =~ ^refs/pull/.* ]]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
(wget -O - "https://github.com/matrix-org/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break
|
||||
done
|
||||
|
||||
# Build initial Dendrite image
|
||||
- run: docker build -t complement-dendrite -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile .
|
||||
- run: docker build --build-arg=CGO=${{ matrix.cgo }} -t complement-dendrite:${{ matrix.postgres }}${{ matrix.api }}${{ matrix.cgo }} -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile .
|
||||
working-directory: dendrite
|
||||
env:
|
||||
DOCKER_BUILDKIT: 1
|
||||
|
|
@ -406,9 +391,8 @@ jobs:
|
|||
shell: bash
|
||||
name: Run Complement Tests
|
||||
env:
|
||||
COMPLEMENT_BASE_IMAGE: complement-dendrite:latest
|
||||
COMPLEMENT_BASE_IMAGE: complement-dendrite:${{ matrix.postgres }}${{ matrix.api }}${{ matrix.cgo }}
|
||||
API: ${{ matrix.api && 1 }}
|
||||
CGO_ENABLED: ${{ matrix.cgo && 1 }}
|
||||
working-directory: complement
|
||||
|
||||
integration-tests-done:
|
||||
|
|
@ -434,7 +418,7 @@ jobs:
|
|||
permissions:
|
||||
packages: write
|
||||
contents: read
|
||||
security-events: write # To upload Trivy sarif files
|
||||
security-events: write # To upload Trivy sarif files
|
||||
if: github.repository == 'matrix-org/dendrite' && github.ref_name == 'main'
|
||||
needs: [integration-tests-done]
|
||||
uses: matrix-org/dendrite/.github/workflows/docker.yml@main
|
||||
|
|
|
|||
8
.github/workflows/docker.yml
vendored
8
.github/workflows/docker.yml
vendored
|
|
@ -32,7 +32,7 @@ jobs:
|
|||
if: github.event_name == 'release' # Only for GitHub releases
|
||||
run: |
|
||||
echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
|
||||
echo "BUILD=$(git rev-parse --short HEAD || "") >> $GITHUB_ENV
|
||||
echo "BUILD=$(git rev-parse --short HEAD || \"\")" >> $GITHUB_ENV
|
||||
BRANCH=$(git symbolic-ref --short HEAD | tr -d \/)
|
||||
[ ${BRANCH} == "main" ] && BRANCH=""
|
||||
echo "BRANCH=${BRANCH}" >> $GITHUB_ENV
|
||||
|
|
@ -112,7 +112,7 @@ jobs:
|
|||
if: github.event_name == 'release' # Only for GitHub releases
|
||||
run: |
|
||||
echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
|
||||
echo "BUILD=$(git rev-parse --short HEAD || "") >> $GITHUB_ENV
|
||||
echo "BUILD=$(git rev-parse --short HEAD || \"\")" >> $GITHUB_ENV
|
||||
BRANCH=$(git symbolic-ref --short HEAD | tr -d \/)
|
||||
[ ${BRANCH} == "main" ] && BRANCH=""
|
||||
echo "BRANCH=${BRANCH}" >> $GITHUB_ENV
|
||||
|
|
@ -191,7 +191,7 @@ jobs:
|
|||
if: github.event_name == 'release' # Only for GitHub releases
|
||||
run: |
|
||||
echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
|
||||
echo "BUILD=$(git rev-parse --short HEAD || "") >> $GITHUB_ENV
|
||||
echo "BUILD=$(git rev-parse --short HEAD || \"\")" >> $GITHUB_ENV
|
||||
BRANCH=$(git symbolic-ref --short HEAD | tr -d \/)
|
||||
[ ${BRANCH} == "main" ] && BRANCH=""
|
||||
echo "BRANCH=${BRANCH}" >> $GITHUB_ENV
|
||||
|
|
@ -258,7 +258,7 @@ jobs:
|
|||
if: github.event_name == 'release' # Only for GitHub releases
|
||||
run: |
|
||||
echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
|
||||
echo "BUILD=$(git rev-parse --short HEAD || "") >> $GITHUB_ENV
|
||||
echo "BUILD=$(git rev-parse --short HEAD || \"\")" >> $GITHUB_ENV
|
||||
BRANCH=$(git symbolic-ref --short HEAD | tr -d \/)
|
||||
[ ${BRANCH} == "main" ] && BRANCH=""
|
||||
echo "BRANCH=${BRANCH}" >> $GITHUB_ENV
|
||||
|
|
|
|||
17
CHANGES.md
17
CHANGES.md
|
|
@ -1,5 +1,22 @@
|
|||
# Changelog
|
||||
|
||||
## Dendrite 0.10.7 (2022-11-04)
|
||||
|
||||
### Features
|
||||
|
||||
* Dendrite will now use a native SQLite port when building with `CGO_ENABLED=0`
|
||||
* A number of `thirdparty` endpoints have been added, improving support for appservices
|
||||
|
||||
### Fixes
|
||||
|
||||
* The `"state"` section of the `/sync` response is no longer limited, so state events should not be dropped unexpectedly
|
||||
* The deduplication of the `"timeline"` and `"state"` sections in `/sync` is now performed after applying history visibility, so state events should not be dropped unexpectedly
|
||||
* The `prev_batch` token returned by `/sync` is now calculated after applying history visibility, so that the pagination boundaries are correct
|
||||
* The room summary membership counts in `/sync` should now be calculated properly in more cases
|
||||
* A false membership leave event should no longer be sent down `/sync` as a result of retiring an accepted invite (contributed by [tak-hntlabs](https://github.com/tak-hntlabs))
|
||||
* Presence updates are now only sent to other servers for which the user shares rooms
|
||||
* A bug which could cause a panic when converting events into the `ClientEvent` format has been fixed
|
||||
|
||||
## Dendrite 0.10.6 (2022-11-01)
|
||||
|
||||
### Features
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ ARG TARGETARCH
|
|||
ARG FLAGS
|
||||
RUN --mount=target=. \
|
||||
--mount=type=cache,target=/root/.cache/go-build \
|
||||
--mount=type=cache,target=/go/pkg/mod \
|
||||
USERARCH=`go env GOARCH` \
|
||||
GOARCH="$TARGETARCH" \
|
||||
GOOS="linux" \
|
||||
|
|
@ -23,7 +24,7 @@ RUN --mount=target=. \
|
|||
go build -v -ldflags="${FLAGS}" -trimpath -o /out/ ./cmd/...
|
||||
|
||||
#
|
||||
# The dendrite base image; mainly creates a user and switches to it
|
||||
# The dendrite base image
|
||||
#
|
||||
FROM alpine:latest AS dendrite-base
|
||||
LABEL org.opencontainers.image.description="Next-generation Matrix homeserver written in Go"
|
||||
|
|
@ -31,8 +32,6 @@ LABEL org.opencontainers.image.source="https://github.com/matrix-org/dendrite"
|
|||
LABEL org.opencontainers.image.licenses="Apache-2.0"
|
||||
LABEL org.opencontainers.image.documentation="https://matrix-org.github.io/dendrite/"
|
||||
LABEL org.opencontainers.image.vendor="The Matrix.org Foundation C.I.C."
|
||||
RUN addgroup dendrite && adduser dendrite -G dendrite -u 1337 -D
|
||||
USER dendrite
|
||||
|
||||
#
|
||||
# Builds the polylith image and only contains the polylith binary
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// AddInternalRoutes registers HTTP handlers for internal API calls
|
||||
|
|
@ -74,7 +75,7 @@ func NewInternalAPI(
|
|||
// events to be sent out.
|
||||
for _, appservice := range base.Cfg.Derived.ApplicationServices {
|
||||
// Create bot account for this AS if it doesn't already exist
|
||||
if err := generateAppServiceAccount(userAPI, appservice); err != nil {
|
||||
if err := generateAppServiceAccount(userAPI, appservice, base.Cfg.Global.ServerName); err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"appservice": appservice.ID,
|
||||
}).WithError(err).Panicf("failed to generate bot account for appservice")
|
||||
|
|
@ -101,11 +102,13 @@ func NewInternalAPI(
|
|||
func generateAppServiceAccount(
|
||||
userAPI userapi.AppserviceUserAPI,
|
||||
as config.ApplicationService,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
var accRes userapi.PerformAccountCreationResponse
|
||||
err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{
|
||||
AccountType: userapi.AccountTypeAppService,
|
||||
Localpart: as.SenderLocalpart,
|
||||
ServerName: serverName,
|
||||
AppServiceID: as.ID,
|
||||
OnConflict: userapi.ConflictUpdate,
|
||||
}, &accRes)
|
||||
|
|
@ -115,6 +118,7 @@ func generateAppServiceAccount(
|
|||
var devRes userapi.PerformDeviceCreationResponse
|
||||
err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{
|
||||
Localpart: as.SenderLocalpart,
|
||||
ServerName: serverName,
|
||||
AccessToken: as.ASToken,
|
||||
DeviceID: &as.SenderLocalpart,
|
||||
DeviceDisplayName: &as.SenderLocalpart,
|
||||
|
|
|
|||
|
|
@ -10,12 +10,13 @@ RUN mkdir /dendrite
|
|||
|
||||
# Utilise Docker caching when downloading dependencies, this stops us needlessly
|
||||
# downloading dependencies every time.
|
||||
ARG CGO
|
||||
RUN --mount=target=. \
|
||||
--mount=type=cache,target=/go/pkg/mod \
|
||||
--mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -o /dendrite ./cmd/generate-config && \
|
||||
go build -o /dendrite ./cmd/generate-keys && \
|
||||
go build -o /dendrite ./cmd/dendrite-monolith-server
|
||||
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \
|
||||
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \
|
||||
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server
|
||||
|
||||
WORKDIR /dendrite
|
||||
RUN ./generate-keys --private-key matrix_key.pem
|
||||
|
|
|
|||
|
|
@ -28,12 +28,13 @@ RUN mkdir /dendrite
|
|||
|
||||
# Utilise Docker caching when downloading dependencies, this stops us needlessly
|
||||
# downloading dependencies every time.
|
||||
ARG CGO
|
||||
RUN --mount=target=. \
|
||||
--mount=type=cache,target=/go/pkg/mod \
|
||||
--mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -o /dendrite ./cmd/generate-config && \
|
||||
go build -o /dendrite ./cmd/generate-keys && \
|
||||
go build -o /dendrite ./cmd/dendrite-monolith-server
|
||||
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \
|
||||
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \
|
||||
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server
|
||||
|
||||
WORKDIR /dendrite
|
||||
RUN ./generate-keys --private-key matrix_key.pem
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte)
|
|||
|
||||
func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) {
|
||||
r := req.(*PasswordRequest)
|
||||
username := strings.ToLower(r.Username())
|
||||
username := r.Username()
|
||||
if username == "" {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
|
|
@ -74,32 +74,43 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login,
|
|||
JSON: jsonerror.BadJSON("A password must be supplied."),
|
||||
}
|
||||
}
|
||||
localpart, _, err := userutil.ParseUsernameParam(username, t.Config.Matrix)
|
||||
localpart, domain, err := userutil.ParseUsernameParam(username, t.Config.Matrix)
|
||||
if err != nil {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
JSON: jsonerror.InvalidUsername(err.Error()),
|
||||
}
|
||||
}
|
||||
if !t.Config.Matrix.IsLocalServerName(domain) {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
JSON: jsonerror.InvalidUsername("The server name is not known."),
|
||||
}
|
||||
}
|
||||
// Squash username to all lowercase letters
|
||||
res := &api.QueryAccountByPasswordResponse{}
|
||||
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: strings.ToLower(localpart), PlaintextPassword: r.Password}, res)
|
||||
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
|
||||
Localpart: strings.ToLower(localpart),
|
||||
ServerName: domain,
|
||||
PlaintextPassword: r.Password,
|
||||
}, res)
|
||||
if err != nil {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: jsonerror.Unknown("unable to fetch account by password"),
|
||||
JSON: jsonerror.Unknown("Unable to fetch account by password."),
|
||||
}
|
||||
}
|
||||
|
||||
if !res.Exists {
|
||||
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
|
||||
Localpart: localpart,
|
||||
ServerName: domain,
|
||||
PlaintextPassword: r.Password,
|
||||
}, res)
|
||||
if err != nil {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: jsonerror.Unknown("unable to fetch account by password"),
|
||||
JSON: jsonerror.Unknown("Unable to fetch account by password."),
|
||||
}
|
||||
}
|
||||
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows
|
||||
|
|
|
|||
|
|
@ -102,6 +102,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
|
|||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
serverName := cfg.Matrix.ServerName
|
||||
localpart, ok := vars["localpart"]
|
||||
if !ok {
|
||||
return util.JSONResponse{
|
||||
|
|
@ -109,6 +110,9 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
|
|||
JSON: jsonerror.MissingArgument("Expecting user localpart."),
|
||||
}
|
||||
}
|
||||
if l, s, err := gomatrixserverlib.SplitID('@', localpart); err == nil {
|
||||
localpart, serverName = l, s
|
||||
}
|
||||
request := struct {
|
||||
Password string `json:"password"`
|
||||
}{}
|
||||
|
|
@ -126,6 +130,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
|
|||
}
|
||||
updateReq := &userapi.PerformPasswordUpdateRequest{
|
||||
Localpart: localpart,
|
||||
ServerName: serverName,
|
||||
Password: request.Password,
|
||||
LogoutDevices: true,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -100,6 +100,7 @@ func completeAuth(
|
|||
DeviceID: login.DeviceID,
|
||||
AccessToken: token,
|
||||
Localpart: localpart,
|
||||
ServerName: serverName,
|
||||
IPAddr: ipAddr,
|
||||
UserAgent: userAgent,
|
||||
}, &performRes)
|
||||
|
|
|
|||
|
|
@ -40,16 +40,17 @@ func GetNotifications(
|
|||
}
|
||||
|
||||
var queryRes userapi.QueryNotificationsResponse
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{
|
||||
Localpart: localpart,
|
||||
From: req.URL.Query().Get("from"),
|
||||
Limit: int(limit),
|
||||
Only: req.URL.Query().Get("only"),
|
||||
Localpart: localpart,
|
||||
ServerName: domain,
|
||||
From: req.URL.Query().Get("from"),
|
||||
Limit: int(limit),
|
||||
Only: req.URL.Query().Get("only"),
|
||||
}, &queryRes)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("QueryNotifications failed")
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ func Password(
|
|||
}
|
||||
|
||||
// Get the local part.
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
|
|
@ -94,8 +94,9 @@ func Password(
|
|||
|
||||
// Ask the user API to perform the password change.
|
||||
passwordReq := &api.PerformPasswordUpdateRequest{
|
||||
Localpart: localpart,
|
||||
Password: r.NewPassword,
|
||||
Localpart: localpart,
|
||||
ServerName: domain,
|
||||
Password: r.NewPassword,
|
||||
}
|
||||
passwordRes := &api.PerformPasswordUpdateResponse{}
|
||||
if err := userAPI.PerformPasswordUpdate(req.Context(), passwordReq, passwordRes); err != nil {
|
||||
|
|
@ -122,8 +123,9 @@ func Password(
|
|||
}
|
||||
|
||||
pushersReq := &api.PerformPusherDeletionRequest{
|
||||
Localpart: localpart,
|
||||
SessionID: device.SessionID,
|
||||
Localpart: localpart,
|
||||
ServerName: domain,
|
||||
SessionID: device.SessionID,
|
||||
}
|
||||
if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed")
|
||||
|
|
|
|||
|
|
@ -31,13 +31,14 @@ func GetPushers(
|
|||
userAPI userapi.ClientUserAPI,
|
||||
) util.JSONResponse {
|
||||
var queryRes userapi.QueryPushersResponse
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{
|
||||
Localpart: localpart,
|
||||
Localpart: localpart,
|
||||
ServerName: domain,
|
||||
}, &queryRes)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed")
|
||||
|
|
@ -59,7 +60,7 @@ func SetPusher(
|
|||
req *http.Request, device *userapi.Device,
|
||||
userAPI userapi.ClientUserAPI,
|
||||
) util.JSONResponse {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
|
|
@ -93,6 +94,7 @@ func SetPusher(
|
|||
|
||||
}
|
||||
body.Localpart = localpart
|
||||
body.ServerName = domain
|
||||
body.SessionID = device.SessionID
|
||||
err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{})
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -588,12 +588,15 @@ func Register(
|
|||
}
|
||||
// Auto generate a numeric username if r.Username is empty
|
||||
if r.Username == "" {
|
||||
res := &userapi.QueryNumericLocalpartResponse{}
|
||||
if err := userAPI.QueryNumericLocalpart(req.Context(), res); err != nil {
|
||||
nreq := &userapi.QueryNumericLocalpartRequest{
|
||||
ServerName: cfg.Matrix.ServerName, // TODO: might not be right
|
||||
}
|
||||
nres := &userapi.QueryNumericLocalpartResponse{}
|
||||
if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
r.Username = strconv.FormatInt(res.ID, 10)
|
||||
r.Username = strconv.FormatInt(nres.ID, 10)
|
||||
}
|
||||
|
||||
// Is this an appservice registration? It will be if the access
|
||||
|
|
@ -676,6 +679,7 @@ func handleGuestRegistration(
|
|||
var devRes userapi.PerformDeviceCreationResponse
|
||||
err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{
|
||||
Localpart: res.Account.Localpart,
|
||||
ServerName: res.Account.ServerName,
|
||||
DeviceDisplayName: r.InitialDisplayName,
|
||||
AccessToken: token,
|
||||
IPAddr: req.RemoteAddr,
|
||||
|
|
|
|||
|
|
@ -157,7 +157,7 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
dendriteAdminRouter.Handle("/admin/resetPassword/{localpart}",
|
||||
dendriteAdminRouter.Handle("/admin/resetPassword/{userID}",
|
||||
httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return AdminResetPassword(req, cfg, device, userAPI)
|
||||
}),
|
||||
|
|
@ -252,7 +252,7 @@ func Setup(
|
|||
return JoinRoomByIDOrAlias(
|
||||
req, device, rsAPI, userAPI, vars["roomIDOrAlias"],
|
||||
)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
if mscCfg.Enabled("msc2753") {
|
||||
|
|
@ -274,7 +274,7 @@ func Setup(
|
|||
v3mux.Handle("/joined_rooms",
|
||||
httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return GetJoinedRooms(req, device, rsAPI)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
v3mux.Handle("/rooms/{roomID}/join",
|
||||
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
|
@ -288,7 +288,7 @@ func Setup(
|
|||
return JoinRoomByIDOrAlias(
|
||||
req, device, rsAPI, userAPI, vars["roomID"],
|
||||
)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/rooms/{roomID}/leave",
|
||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
|
@ -302,7 +302,7 @@ func Setup(
|
|||
return LeaveRoomByID(
|
||||
req, device, rsAPI, vars["roomID"],
|
||||
)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/rooms/{roomID}/unpeek",
|
||||
httputil.MakeAuthAPI("unpeek", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
|
@ -361,7 +361,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}",
|
||||
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
|
@ -372,7 +372,7 @@ func Setup(
|
|||
txnID := vars["txnID"]
|
||||
return SendEvent(req, device, vars["roomID"], vars["eventType"], &txnID,
|
||||
nil, cfg, rsAPI, transactionsCache)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
|
@ -381,7 +381,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"])
|
||||
})).Methods(http.MethodGet, http.MethodOptions)
|
||||
}, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
|
|
@ -400,7 +400,7 @@ func Setup(
|
|||
eventType := strings.TrimSuffix(vars["type"], "/")
|
||||
eventFormat := req.URL.Query().Get("format") == "event"
|
||||
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat)
|
||||
})).Methods(http.MethodGet, http.MethodOptions)
|
||||
}, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
|
|
@ -409,7 +409,7 @@ func Setup(
|
|||
}
|
||||
eventFormat := req.URL.Query().Get("format") == "event"
|
||||
return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat)
|
||||
})).Methods(http.MethodGet, http.MethodOptions)
|
||||
}, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}",
|
||||
httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
|
@ -420,7 +420,7 @@ func Setup(
|
|||
emptyString := ""
|
||||
eventType := strings.TrimSuffix(vars["eventType"], "/")
|
||||
return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
|
||||
|
|
@ -431,7 +431,7 @@ func Setup(
|
|||
}
|
||||
stateKey := vars["stateKey"]
|
||||
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
|
||||
|
|
@ -575,7 +575,7 @@ func Setup(
|
|||
}
|
||||
txnID := vars["txnID"]
|
||||
return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
// This is only here because sytest refers to /unstable for this endpoint
|
||||
|
|
@ -589,7 +589,7 @@ func Setup(
|
|||
}
|
||||
txnID := vars["txnID"]
|
||||
return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/account/whoami",
|
||||
|
|
@ -598,7 +598,7 @@ func Setup(
|
|||
return *r
|
||||
}
|
||||
return Whoami(req, device)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/account/password",
|
||||
|
|
@ -830,7 +830,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return SetDisplayName(req, userAPI, device, vars["userID"], cfg, rsAPI)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
// Browsers use the OPTIONS HTTP method to check if the CORS policy allows
|
||||
// PUT requests, so we need to allow this method
|
||||
|
|
@ -871,7 +871,7 @@ func Setup(
|
|||
v3mux.Handle("/thirdparty/protocols",
|
||||
httputil.MakeAuthAPI("thirdparty_protocols", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return Protocols(req, asAPI, device, "")
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/thirdparty/protocol/{protocolID}",
|
||||
|
|
@ -881,7 +881,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return Protocols(req, asAPI, device, vars["protocolID"])
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/thirdparty/user/{protocolID}",
|
||||
|
|
@ -891,13 +891,13 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return User(req, asAPI, device, vars["protocolID"], req.URL.Query())
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/thirdparty/user",
|
||||
httputil.MakeAuthAPI("thirdparty_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return User(req, asAPI, device, "", req.URL.Query())
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/thirdparty/location/{protocolID}",
|
||||
|
|
@ -907,13 +907,13 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return Location(req, asAPI, device, vars["protocolID"], req.URL.Query())
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/thirdparty/location",
|
||||
httputil.MakeAuthAPI("thirdparty_location", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return Location(req, asAPI, device, "", req.URL.Query())
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomID}/initialSync",
|
||||
|
|
@ -1054,7 +1054,7 @@ func Setup(
|
|||
v3mux.Handle("/devices",
|
||||
httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return GetDevicesByLocalpart(req, userAPI, device)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/devices/{deviceID}",
|
||||
|
|
@ -1064,7 +1064,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return GetDeviceByID(req, userAPI, device, vars["deviceID"])
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/devices/{deviceID}",
|
||||
|
|
@ -1074,7 +1074,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return UpdateDeviceByID(req, userAPI, device, vars["deviceID"])
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/devices/{deviceID}",
|
||||
|
|
@ -1116,21 +1116,21 @@ func Setup(
|
|||
|
||||
// Stub implementations for sytest
|
||||
v3mux.Handle("/events",
|
||||
httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {
|
||||
httputil.MakeAuthAPI("events", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{
|
||||
"chunk": []interface{}{},
|
||||
"start": "",
|
||||
"end": "",
|
||||
}}
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/initialSync",
|
||||
httputil.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse {
|
||||
httputil.MakeAuthAPI("initial_sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{
|
||||
"end": "",
|
||||
}}
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/user/{userId}/rooms/{roomId}/tags",
|
||||
|
|
@ -1169,7 +1169,7 @@ func Setup(
|
|||
return *r
|
||||
}
|
||||
return GetCapabilities(req, rsAPI)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
// Key Backup Versions (Metadata)
|
||||
|
|
@ -1350,7 +1350,7 @@ func Setup(
|
|||
|
||||
postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return UploadCrossSigningDeviceSignatures(req, keyAPI, device)
|
||||
})
|
||||
}, httputil.WithAllowGuests())
|
||||
|
||||
v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
|
@ -1362,22 +1362,22 @@ func Setup(
|
|||
v3mux.Handle("/keys/upload/{deviceID}",
|
||||
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return UploadKeys(req, keyAPI, device)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/keys/upload",
|
||||
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return UploadKeys(req, keyAPI, device)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/keys/query",
|
||||
httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return QueryKeys(req, keyAPI, device)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/keys/claim",
|
||||
httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return ClaimKeys(req, keyAPI)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}",
|
||||
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
|
|
|||
|
|
@ -286,6 +286,7 @@ func getSenderDevice(
|
|||
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
|
||||
AccountType: userapi.AccountTypeUser,
|
||||
Localpart: cfg.Matrix.ServerNotices.LocalPart,
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
OnConflict: userapi.ConflictUpdate,
|
||||
}, &accRes)
|
||||
if err != nil {
|
||||
|
|
@ -295,8 +296,9 @@ func getSenderDevice(
|
|||
// Set the avatarurl for the user
|
||||
avatarRes := &userapi.PerformSetAvatarURLResponse{}
|
||||
if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{
|
||||
Localpart: cfg.Matrix.ServerNotices.LocalPart,
|
||||
AvatarURL: cfg.Matrix.ServerNotices.AvatarURL,
|
||||
Localpart: cfg.Matrix.ServerNotices.LocalPart,
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
AvatarURL: cfg.Matrix.ServerNotices.AvatarURL,
|
||||
}, avatarRes); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed")
|
||||
return nil, err
|
||||
|
|
@ -308,6 +310,7 @@ func getSenderDevice(
|
|||
displayNameRes := &userapi.PerformUpdateDisplayNameResponse{}
|
||||
if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{
|
||||
Localpart: cfg.Matrix.ServerNotices.LocalPart,
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
DisplayName: cfg.Matrix.ServerNotices.DisplayName,
|
||||
}, displayNameRes); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed")
|
||||
|
|
@ -353,6 +356,7 @@ func getSenderDevice(
|
|||
var devRes userapi.PerformDeviceCreationResponse
|
||||
err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{
|
||||
Localpart: cfg.Matrix.ServerNotices.LocalPart,
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
DeviceDisplayName: &cfg.Matrix.ServerNotices.LocalPart,
|
||||
AccessToken: token,
|
||||
NoDeviceListUpdate: true,
|
||||
|
|
|
|||
|
|
@ -136,16 +136,17 @@ func CheckAndSave3PIDAssociation(
|
|||
}
|
||||
|
||||
// Save the association in the database
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
||||
if err = threePIDAPI.PerformSaveThreePIDAssociation(req.Context(), &api.PerformSaveThreePIDAssociationRequest{
|
||||
ThreePID: address,
|
||||
Localpart: localpart,
|
||||
Medium: medium,
|
||||
ThreePID: address,
|
||||
Localpart: localpart,
|
||||
ServerName: domain,
|
||||
Medium: medium,
|
||||
}, &struct{}{}); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.PerformSaveThreePIDAssociation failed")
|
||||
return jsonerror.InternalServerError()
|
||||
|
|
@ -161,7 +162,7 @@ func CheckAndSave3PIDAssociation(
|
|||
func GetAssociated3PIDs(
|
||||
req *http.Request, threepidAPI api.ClientUserAPI, device *api.Device,
|
||||
) util.JSONResponse {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
|
|
@ -169,7 +170,8 @@ func GetAssociated3PIDs(
|
|||
|
||||
res := &api.QueryThreePIDsForLocalpartResponse{}
|
||||
err = threepidAPI.QueryThreePIDsForLocalpart(req.Context(), &api.QueryThreePIDsForLocalpartRequest{
|
||||
Localpart: localpart,
|
||||
Localpart: localpart,
|
||||
ServerName: domain,
|
||||
}, res)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.QueryThreePIDsForLocalpart failed")
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ var (
|
|||
flagHead = flag.String("head", "", "Location to a dendrite repository to treat as HEAD instead of Github")
|
||||
flagDockerHost = flag.String("docker-host", "localhost", "The hostname of the docker client. 'localhost' if running locally, 'host.docker.internal' if running in Docker.")
|
||||
flagDirect = flag.Bool("direct", false, "If a direct upgrade from the defined FROM version to TO should be done")
|
||||
flagSqlite = flag.Bool("sqlite", false, "Test SQLite instead of PostgreSQL")
|
||||
alphaNumerics = regexp.MustCompile("[^a-zA-Z0-9]+")
|
||||
)
|
||||
|
||||
|
|
@ -49,7 +50,7 @@ const HEAD = "HEAD"
|
|||
// due to the error:
|
||||
// When using COPY with more than one source file, the destination must be a directory and end with a /
|
||||
// We need to run a postgres anyway, so use the dockerfile associated with Complement instead.
|
||||
const Dockerfile = `FROM golang:1.18-stretch as build
|
||||
const DockerfilePostgreSQL = `FROM golang:1.18-stretch as build
|
||||
RUN apt-get update && apt-get install -y postgresql
|
||||
WORKDIR /build
|
||||
|
||||
|
|
@ -92,6 +93,42 @@ ENV SERVER_NAME=localhost
|
|||
EXPOSE 8008 8448
|
||||
CMD /build/run_dendrite.sh `
|
||||
|
||||
const DockerfileSQLite = `FROM golang:1.18-stretch as build
|
||||
RUN apt-get update && apt-get install -y postgresql
|
||||
WORKDIR /build
|
||||
|
||||
# Copy the build context to the repo as this is the right dendrite code. This is different to the
|
||||
# Complement Dockerfile which wgets a branch.
|
||||
COPY . .
|
||||
|
||||
RUN go build ./cmd/dendrite-monolith-server
|
||||
RUN go build ./cmd/generate-keys
|
||||
RUN go build ./cmd/generate-config
|
||||
RUN ./generate-config --ci > dendrite.yaml
|
||||
RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key
|
||||
|
||||
# Make sure the SQLite databases are in a persistent location, we're already mapping
|
||||
# the postgresql folder so let's just use that for simplicity
|
||||
RUN sed -i "s%connection_string:.file:%connection_string: file:\/var\/lib\/postgresql\/9.6\/main\/%g" dendrite.yaml
|
||||
|
||||
# This entry script starts postgres, waits for it to be up then starts dendrite
|
||||
RUN echo '\
|
||||
sed -i "s/server_name: localhost/server_name: ${SERVER_NAME}/g" dendrite.yaml \n\
|
||||
PARAMS="--tls-cert server.crt --tls-key server.key --config dendrite.yaml" \n\
|
||||
./dendrite-monolith-server --really-enable-open-registration ${PARAMS} || ./dendrite-monolith-server ${PARAMS} \n\
|
||||
' > run_dendrite.sh && chmod +x run_dendrite.sh
|
||||
|
||||
ENV SERVER_NAME=localhost
|
||||
EXPOSE 8008 8448
|
||||
CMD /build/run_dendrite.sh `
|
||||
|
||||
func dockerfile() []byte {
|
||||
if *flagSqlite {
|
||||
return []byte(DockerfileSQLite)
|
||||
}
|
||||
return []byte(DockerfilePostgreSQL)
|
||||
}
|
||||
|
||||
const dendriteUpgradeTestLabel = "dendrite_upgrade_test"
|
||||
|
||||
// downloadArchive downloads an arbitrary github archive of the form:
|
||||
|
|
@ -150,7 +187,7 @@ func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir,
|
|||
if branchOrTagName == HEAD && *flagHead != "" {
|
||||
log.Printf("%s: Using %s as HEAD", branchOrTagName, *flagHead)
|
||||
// add top level Dockerfile
|
||||
err = os.WriteFile(path.Join(*flagHead, "Dockerfile"), []byte(Dockerfile), os.ModePerm)
|
||||
err = os.WriteFile(path.Join(*flagHead, "Dockerfile"), dockerfile(), os.ModePerm)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("custom HEAD: failed to inject /Dockerfile: %w", err)
|
||||
}
|
||||
|
|
@ -166,7 +203,7 @@ func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir,
|
|||
// pull an archive, this contains a top-level directory which screws with the build context
|
||||
// which we need to fix up post download
|
||||
u := fmt.Sprintf("https://github.com/matrix-org/dendrite/archive/%s.tar.gz", branchOrTagName)
|
||||
tarball, err = downloadArchive(httpClient, tmpDir, u, []byte(Dockerfile))
|
||||
tarball, err = downloadArchive(httpClient, tmpDir, u, dockerfile())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to download archive %s: %w", u, err)
|
||||
}
|
||||
|
|
@ -367,7 +404,8 @@ func runImage(dockerClient *client.Client, volumeName, version, imageID string)
|
|||
// hit /versions to check it is up
|
||||
var lastErr error
|
||||
for i := 0; i < 500; i++ {
|
||||
res, err := http.Get(versionsURL)
|
||||
var res *http.Response
|
||||
res, err = http.Get(versionsURL)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("GET %s => error: %s", versionsURL, err)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
|
@ -381,18 +419,22 @@ func runImage(dockerClient *client.Client, volumeName, version, imageID string)
|
|||
lastErr = nil
|
||||
break
|
||||
}
|
||||
if lastErr != nil {
|
||||
logs, err := dockerClient.ContainerLogs(context.Background(), containerID, types.ContainerLogsOptions{
|
||||
ShowStdout: true,
|
||||
ShowStderr: true,
|
||||
})
|
||||
// ignore errors when cannot get logs, it's just for debugging anyways
|
||||
if err == nil {
|
||||
logbody, err := io.ReadAll(logs)
|
||||
if err == nil {
|
||||
log.Printf("Container logs:\n\n%s\n\n", string(logbody))
|
||||
logs, err := dockerClient.ContainerLogs(context.Background(), containerID, types.ContainerLogsOptions{
|
||||
ShowStdout: true,
|
||||
ShowStderr: true,
|
||||
Follow: true,
|
||||
})
|
||||
// ignore errors when cannot get logs, it's just for debugging anyways
|
||||
if err == nil {
|
||||
go func() {
|
||||
for {
|
||||
if body, err := io.ReadAll(logs); err == nil && len(body) > 0 {
|
||||
log.Printf("%s: %s", version, string(body))
|
||||
} else {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
return baseURL, containerID, lastErr
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/federationapi/queue"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage"
|
||||
fedTypes "github.com/matrix-org/dendrite/federationapi/types"
|
||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
|
|
@ -39,6 +40,7 @@ type OutputPresenceConsumer struct {
|
|||
db storage.Database
|
||||
queues *queue.OutgoingQueues
|
||||
isLocalServerName func(gomatrixserverlib.ServerName) bool
|
||||
rsAPI roomserverAPI.FederationRoomserverAPI
|
||||
topic string
|
||||
outboundPresenceEnabled bool
|
||||
}
|
||||
|
|
@ -50,6 +52,7 @@ func NewOutputPresenceConsumer(
|
|||
js nats.JetStreamContext,
|
||||
queues *queue.OutgoingQueues,
|
||||
store storage.Database,
|
||||
rsAPI roomserverAPI.FederationRoomserverAPI,
|
||||
) *OutputPresenceConsumer {
|
||||
return &OutputPresenceConsumer{
|
||||
ctx: process.Context(),
|
||||
|
|
@ -60,6 +63,7 @@ func NewOutputPresenceConsumer(
|
|||
durable: cfg.Matrix.JetStream.Durable("FederationAPIPresenceConsumer"),
|
||||
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent),
|
||||
outboundPresenceEnabled: cfg.Matrix.Presence.EnableOutbound,
|
||||
rsAPI: rsAPI,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -89,6 +93,16 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg
|
|||
return true
|
||||
}
|
||||
|
||||
var queryRes roomserverAPI.QueryRoomsForUserResponse
|
||||
err = t.rsAPI.QueryRoomsForUser(t.ctx, &roomserverAPI.QueryRoomsForUserRequest{
|
||||
UserID: userID,
|
||||
WantMembership: "join",
|
||||
}, &queryRes)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("failed to calculate joined rooms for user")
|
||||
return true
|
||||
}
|
||||
|
||||
presence := msg.Header.Get("presence")
|
||||
|
||||
ts, err := strconv.Atoi(msg.Header.Get("last_active_ts"))
|
||||
|
|
@ -96,11 +110,13 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg
|
|||
return true
|
||||
}
|
||||
|
||||
joined, err := t.db.GetAllJoinedHosts(ctx)
|
||||
// send this presence to all servers who share rooms with this user.
|
||||
joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("failed to get joined hosts")
|
||||
return true
|
||||
}
|
||||
|
||||
if len(joined) == 0 {
|
||||
return true
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,10 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
syncAPITypes "github.com/matrix-org/dendrite/syncapi/types"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/nats-io/nats.go"
|
||||
|
|
@ -34,14 +38,16 @@ import (
|
|||
|
||||
// OutputRoomEventConsumer consumes events that originated in the room server.
|
||||
type OutputRoomEventConsumer struct {
|
||||
ctx context.Context
|
||||
cfg *config.FederationAPI
|
||||
rsAPI api.FederationRoomserverAPI
|
||||
jetstream nats.JetStreamContext
|
||||
durable string
|
||||
db storage.Database
|
||||
queues *queue.OutgoingQueues
|
||||
topic string
|
||||
ctx context.Context
|
||||
cfg *config.FederationAPI
|
||||
rsAPI api.FederationRoomserverAPI
|
||||
jetstream nats.JetStreamContext
|
||||
natsClient *nats.Conn
|
||||
durable string
|
||||
db storage.Database
|
||||
queues *queue.OutgoingQueues
|
||||
topic string
|
||||
topicPresence string
|
||||
}
|
||||
|
||||
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
|
||||
|
|
@ -49,19 +55,22 @@ func NewOutputRoomEventConsumer(
|
|||
process *process.ProcessContext,
|
||||
cfg *config.FederationAPI,
|
||||
js nats.JetStreamContext,
|
||||
natsClient *nats.Conn,
|
||||
queues *queue.OutgoingQueues,
|
||||
store storage.Database,
|
||||
rsAPI api.FederationRoomserverAPI,
|
||||
) *OutputRoomEventConsumer {
|
||||
return &OutputRoomEventConsumer{
|
||||
ctx: process.Context(),
|
||||
cfg: cfg,
|
||||
jetstream: js,
|
||||
db: store,
|
||||
queues: queues,
|
||||
rsAPI: rsAPI,
|
||||
durable: cfg.Matrix.JetStream.Durable("FederationAPIRoomServerConsumer"),
|
||||
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent),
|
||||
ctx: process.Context(),
|
||||
cfg: cfg,
|
||||
jetstream: js,
|
||||
natsClient: natsClient,
|
||||
db: store,
|
||||
queues: queues,
|
||||
rsAPI: rsAPI,
|
||||
durable: cfg.Matrix.JetStream.Durable("FederationAPIRoomServerConsumer"),
|
||||
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent),
|
||||
topicPresence: cfg.Matrix.JetStream.Prefixed(jetstream.RequestPresence),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -146,6 +155,7 @@ func (s *OutputRoomEventConsumer) processInboundPeek(orp api.OutputNewInboundPee
|
|||
// processMessage updates the list of currently joined hosts in the room
|
||||
// and then sends the event to the hosts that were joined before the event.
|
||||
func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rewritesState bool) error {
|
||||
|
||||
addsStateEvents, missingEventIDs := ore.NeededStateEventIDs()
|
||||
|
||||
// Ask the roomserver and add in the rest of the results into the set.
|
||||
|
|
@ -184,6 +194,14 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew
|
|||
return err
|
||||
}
|
||||
|
||||
// If we added new hosts, inform them about our known presence events for this room
|
||||
if len(addsJoinedHosts) > 0 && ore.Event.Type() == gomatrixserverlib.MRoomMember && ore.Event.StateKey() != nil {
|
||||
membership, _ := ore.Event.Membership()
|
||||
if membership == gomatrixserverlib.Join {
|
||||
s.sendPresence(ore.Event.RoomID(), addsJoinedHosts)
|
||||
}
|
||||
}
|
||||
|
||||
if oldJoinedHosts == nil {
|
||||
// This means that there is nothing to update as this is a duplicate
|
||||
// message.
|
||||
|
|
@ -213,6 +231,76 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew
|
|||
)
|
||||
}
|
||||
|
||||
func (s *OutputRoomEventConsumer) sendPresence(roomID string, addedJoined []types.JoinedHost) {
|
||||
joined := make([]gomatrixserverlib.ServerName, len(addedJoined))
|
||||
for _, added := range addedJoined {
|
||||
joined = append(joined, added.ServerName)
|
||||
}
|
||||
|
||||
// get our locally joined users
|
||||
var queryRes api.QueryMembershipsForRoomResponse
|
||||
err := s.rsAPI.QueryMembershipsForRoom(s.ctx, &api.QueryMembershipsForRoomRequest{
|
||||
JoinedOnly: true,
|
||||
LocalOnly: true,
|
||||
RoomID: roomID,
|
||||
}, &queryRes)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("failed to calculate joined rooms for user")
|
||||
return
|
||||
}
|
||||
|
||||
// send every presence we know about to the remote server
|
||||
content := types.Presence{}
|
||||
for _, ev := range queryRes.JoinEvents {
|
||||
msg := nats.NewMsg(s.topicPresence)
|
||||
msg.Header.Set(jetstream.UserID, ev.Sender)
|
||||
|
||||
var presence *nats.Msg
|
||||
presence, err = s.natsClient.RequestMsg(msg, time.Second*10)
|
||||
if err != nil {
|
||||
log.WithError(err).Errorf("unable to get presence")
|
||||
continue
|
||||
}
|
||||
|
||||
statusMsg := presence.Header.Get("status_msg")
|
||||
e := presence.Header.Get("error")
|
||||
if e != "" {
|
||||
continue
|
||||
}
|
||||
var lastActive int
|
||||
lastActive, err = strconv.Atoi(presence.Header.Get("last_active_ts"))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
p := syncAPITypes.PresenceInternal{LastActiveTS: gomatrixserverlib.Timestamp(lastActive)}
|
||||
|
||||
content.Push = append(content.Push, types.PresenceContent{
|
||||
CurrentlyActive: p.CurrentlyActive(),
|
||||
LastActiveAgo: p.LastActiveAgo(),
|
||||
Presence: presence.Header.Get("presence"),
|
||||
StatusMsg: &statusMsg,
|
||||
UserID: ev.Sender,
|
||||
})
|
||||
}
|
||||
|
||||
if len(content.Push) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
edu := &gomatrixserverlib.EDU{
|
||||
Type: gomatrixserverlib.MPresence,
|
||||
Origin: string(s.cfg.Matrix.ServerName),
|
||||
}
|
||||
if edu.Content, err = json.Marshal(content); err != nil {
|
||||
log.WithError(err).Error("failed to marshal EDU JSON")
|
||||
return
|
||||
}
|
||||
if err := s.queues.SendEDU(edu, s.cfg.Matrix.ServerName, joined); err != nil {
|
||||
log.WithError(err).Error("failed to send EDU")
|
||||
}
|
||||
}
|
||||
|
||||
// joinedHostsAtEvent works out a list of matrix servers that were joined to
|
||||
// the room at the event (including peeking ones)
|
||||
// It is important to use the state at the event for sending messages because:
|
||||
|
|
|
|||
|
|
@ -118,21 +118,29 @@ func NewInternalAPI(
|
|||
|
||||
stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1)
|
||||
|
||||
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
|
||||
js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
|
||||
|
||||
signingInfo := map[gomatrixserverlib.ServerName]*queue.SigningInfo{}
|
||||
for _, serverName := range append(
|
||||
[]gomatrixserverlib.ServerName{base.Cfg.Global.ServerName},
|
||||
base.Cfg.Global.SecondaryServerNames...,
|
||||
) {
|
||||
signingInfo[serverName] = &queue.SigningInfo{
|
||||
KeyID: cfg.Matrix.KeyID,
|
||||
PrivateKey: cfg.Matrix.PrivateKey,
|
||||
ServerName: serverName,
|
||||
}
|
||||
}
|
||||
|
||||
queues := queue.NewOutgoingQueues(
|
||||
federationDB, base.ProcessContext,
|
||||
cfg.Matrix.DisableFederation,
|
||||
cfg.Matrix.ServerName, federation, rsAPI, &stats,
|
||||
&queue.SigningInfo{
|
||||
KeyID: cfg.Matrix.KeyID,
|
||||
PrivateKey: cfg.Matrix.PrivateKey,
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
},
|
||||
signingInfo,
|
||||
)
|
||||
|
||||
rsConsumer := consumers.NewOutputRoomEventConsumer(
|
||||
base.ProcessContext, cfg, js, queues,
|
||||
base.ProcessContext, cfg, js, nats, queues,
|
||||
federationDB, rsAPI,
|
||||
)
|
||||
if err = rsConsumer.Start(); err != nil {
|
||||
|
|
@ -164,7 +172,7 @@ func NewInternalAPI(
|
|||
}
|
||||
|
||||
presenceConsumer := consumers.NewOutputPresenceConsumer(
|
||||
base.ProcessContext, cfg, js, queues, federationDB,
|
||||
base.ProcessContext, cfg, js, queues, federationDB, rsAPI,
|
||||
)
|
||||
if err = presenceConsumer.Start(); err != nil {
|
||||
logrus.WithError(err).Panic("failed to start presence consumer")
|
||||
|
|
|
|||
|
|
@ -137,7 +137,7 @@ func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err
|
|||
}
|
||||
|
||||
// Get the keys and JSON-ify them.
|
||||
keys := routing.LocalKeys(s.config)
|
||||
keys := routing.LocalKeys(s.config, gomatrixserverlib.ServerName(req.Host))
|
||||
body, err := json.MarshalIndent(keys.JSON, "", " ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ type destinationQueue struct {
|
|||
queues *OutgoingQueues
|
||||
db storage.Database
|
||||
process *process.ProcessContext
|
||||
signing *SigningInfo
|
||||
signing map[gomatrixserverlib.ServerName]*SigningInfo
|
||||
rsAPI api.FederationRoomserverAPI
|
||||
client fedapi.FederationClient // federation client
|
||||
origin gomatrixserverlib.ServerName // origin of requests
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ type OutgoingQueues struct {
|
|||
origin gomatrixserverlib.ServerName
|
||||
client fedapi.FederationClient
|
||||
statistics *statistics.Statistics
|
||||
signing *SigningInfo
|
||||
signing map[gomatrixserverlib.ServerName]*SigningInfo
|
||||
queuesMutex sync.Mutex // protects the below
|
||||
queues map[gomatrixserverlib.ServerName]*destinationQueue
|
||||
}
|
||||
|
|
@ -91,7 +91,7 @@ func NewOutgoingQueues(
|
|||
client fedapi.FederationClient,
|
||||
rsAPI api.FederationRoomserverAPI,
|
||||
statistics *statistics.Statistics,
|
||||
signing *SigningInfo,
|
||||
signing map[gomatrixserverlib.ServerName]*SigningInfo,
|
||||
) *OutgoingQueues {
|
||||
queues := &OutgoingQueues{
|
||||
disabled: disabled,
|
||||
|
|
@ -199,11 +199,10 @@ func (oqs *OutgoingQueues) SendEvent(
|
|||
log.Trace("Federation is disabled, not sending event")
|
||||
return nil
|
||||
}
|
||||
if origin != oqs.origin {
|
||||
// TODO: Support virtual hosting; gh issue #577.
|
||||
if _, ok := oqs.signing[origin]; !ok {
|
||||
return fmt.Errorf(
|
||||
"sendevent: unexpected server to send as: got %q expected %q",
|
||||
origin, oqs.origin,
|
||||
"sendevent: unexpected server to send as %q",
|
||||
origin,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
@ -214,7 +213,9 @@ func (oqs *OutgoingQueues) SendEvent(
|
|||
destmap[d] = struct{}{}
|
||||
}
|
||||
delete(destmap, oqs.origin)
|
||||
delete(destmap, oqs.signing.ServerName)
|
||||
for local := range oqs.signing {
|
||||
delete(destmap, local)
|
||||
}
|
||||
|
||||
// Check if any of the destinations are prohibited by server ACLs.
|
||||
for destination := range destmap {
|
||||
|
|
@ -288,11 +289,10 @@ func (oqs *OutgoingQueues) SendEDU(
|
|||
log.Trace("Federation is disabled, not sending EDU")
|
||||
return nil
|
||||
}
|
||||
if origin != oqs.origin {
|
||||
// TODO: Support virtual hosting; gh issue #577.
|
||||
if _, ok := oqs.signing[origin]; !ok {
|
||||
return fmt.Errorf(
|
||||
"sendevent: unexpected server to send as: got %q expected %q",
|
||||
origin, oqs.origin,
|
||||
"sendevent: unexpected server to send as %q",
|
||||
origin,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
@ -303,7 +303,9 @@ func (oqs *OutgoingQueues) SendEDU(
|
|||
destmap[d] = struct{}{}
|
||||
}
|
||||
delete(destmap, oqs.origin)
|
||||
delete(destmap, oqs.signing.ServerName)
|
||||
for local := range oqs.signing {
|
||||
delete(destmap, local)
|
||||
}
|
||||
|
||||
// There is absolutely no guarantee that the EDU will have a room_id
|
||||
// field, as it is not required by the spec. However, if it *does*
|
||||
|
|
|
|||
|
|
@ -350,10 +350,12 @@ func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T
|
|||
}
|
||||
rs := &stubFederationRoomServerAPI{}
|
||||
stats := statistics.NewStatistics(db, failuresUntilBlacklist)
|
||||
signingInfo := &SigningInfo{
|
||||
KeyID: "ed21019:auto",
|
||||
PrivateKey: test.PrivateKeyA,
|
||||
ServerName: "localhost",
|
||||
signingInfo := map[gomatrixserverlib.ServerName]*SigningInfo{
|
||||
"localhost": {
|
||||
KeyID: "ed21019:auto",
|
||||
PrivateKey: test.PrivateKeyA,
|
||||
ServerName: "localhost",
|
||||
},
|
||||
}
|
||||
queues := NewOutgoingQueues(db, processContext, false, "localhost", fc, rs, &stats, signingInfo)
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ package routing
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
|
|
@ -134,18 +135,21 @@ func ClaimOneTimeKeys(
|
|||
|
||||
// LocalKeys returns the local keys for the server.
|
||||
// See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys
|
||||
func LocalKeys(cfg *config.FederationAPI) util.JSONResponse {
|
||||
keys, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod))
|
||||
func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) util.JSONResponse {
|
||||
keys, err := localKeys(cfg, serverName, time.Now().Add(cfg.Matrix.KeyValidityPeriod))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: keys}
|
||||
}
|
||||
|
||||
func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) {
|
||||
func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) {
|
||||
var keys gomatrixserverlib.ServerKeys
|
||||
if !cfg.Matrix.IsLocalServerName(serverName) {
|
||||
return nil, fmt.Errorf("server name not known")
|
||||
}
|
||||
|
||||
keys.ServerName = cfg.Matrix.ServerName
|
||||
keys.ServerName = serverName
|
||||
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(validUntil)
|
||||
|
||||
publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)
|
||||
|
|
@ -172,7 +176,7 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver
|
|||
}
|
||||
|
||||
keys.Raw, err = gomatrixserverlib.SignJSON(
|
||||
string(cfg.Matrix.ServerName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, toSign,
|
||||
string(serverName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, toSign,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -186,6 +190,14 @@ func NotaryKeys(
|
|||
fsAPI federationAPI.FederationInternalAPI,
|
||||
req *gomatrixserverlib.PublicKeyNotaryLookupRequest,
|
||||
) util.JSONResponse {
|
||||
serverName := gomatrixserverlib.ServerName(httpReq.Host) // TODO: this is not ideal
|
||||
if !cfg.Matrix.IsLocalServerName(serverName) {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusNotFound,
|
||||
JSON: jsonerror.NotFound("Server name not known"),
|
||||
}
|
||||
}
|
||||
|
||||
if req == nil {
|
||||
req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{}
|
||||
if reqErr := clienthttputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil {
|
||||
|
|
@ -201,7 +213,7 @@ func NotaryKeys(
|
|||
for serverName, kidToCriteria := range req.ServerKeys {
|
||||
var keyList []gomatrixserverlib.ServerKeys
|
||||
if serverName == cfg.Matrix.ServerName {
|
||||
if k, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil {
|
||||
if k, err := localKeys(cfg, serverName, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil {
|
||||
keyList = append(keyList, *k)
|
||||
} else {
|
||||
return util.ErrorResponse(err)
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ func Setup(
|
|||
}
|
||||
|
||||
localKeys := httputil.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse {
|
||||
return LocalKeys(cfg)
|
||||
return LocalKeys(cfg, gomatrixserverlib.ServerName(req.Host))
|
||||
})
|
||||
|
||||
notaryKeys := httputil.MakeExternalAPI("notarykeys", func(req *http.Request) util.JSONResponse {
|
||||
|
|
|
|||
4
go.mod
4
go.mod
|
|
@ -8,7 +8,7 @@ require (
|
|||
github.com/Masterminds/semver/v3 v3.1.1
|
||||
github.com/blevesearch/bleve/v2 v2.3.4
|
||||
github.com/codeclysm/extract v2.2.0+incompatible
|
||||
github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d
|
||||
github.com/dgraph-io/ristretto v0.1.1
|
||||
github.com/docker/docker v20.10.19+incompatible
|
||||
github.com/docker/go-connections v0.4.0
|
||||
github.com/getsentry/sentry-go v0.14.0
|
||||
|
|
@ -22,7 +22,7 @@ require (
|
|||
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
|
||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
|
||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20221101165746-0e4a8bb6db7e
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20221109092408-715dc88e62e2
|
||||
github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37
|
||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
||||
github.com/mattn/go-sqlite3 v1.14.15
|
||||
|
|
|
|||
10
go.sum
10
go.sum
|
|
@ -141,8 +141,8 @@ github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ
|
|||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d h1:Wrc3UKTS+cffkOx0xRGFC+ZesNuTfn0ThvEC72N0krk=
|
||||
github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d/go.mod h1:RAy2GVV4sTWVlNMavv3xhLsk18rxhfhDnombTe6EF5c=
|
||||
github.com/dgraph-io/ristretto v0.1.1 h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWajOK8=
|
||||
github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkzgwUve0VDWWA=
|
||||
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA=
|
||||
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
|
||||
github.com/docker/distribution v2.8.1+incompatible h1:Q50tZOPR6T/hjNsyc9g8/syEs6bk8XXApsHjKukMl68=
|
||||
|
|
@ -348,8 +348,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
|
|||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
|
||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
|
||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20221101165746-0e4a8bb6db7e h1:6I34fdyiHMRCxL6GOb/G8ZyI1WWlb6ZxCF2hIGSMSCc=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20221101165746-0e4a8bb6db7e/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20221109092408-715dc88e62e2 h1:Bet5n+//Yh+A2SuPHD67N8jrOhC/EIKvEgisfsKhTss=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20221109092408-715dc88e62e2/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
|
||||
github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 h1:CQWFrgH9TJOU2f2qCDhGwaSdAnmgSu3/f+2xcf/Fse4=
|
||||
github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc=
|
||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
|
||||
|
|
@ -526,7 +526,6 @@ go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
|
|||
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
|
||||
go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
|
||||
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
|
||||
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ=
|
||||
go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
||||
golang.org/x/crypto v0.0.0-20180723164146-c126467f60eb/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
|
|
@ -697,6 +696,7 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
|
|
|
|||
|
|
@ -42,10 +42,26 @@ type BasicAuth struct {
|
|||
Password string `yaml:"password"`
|
||||
}
|
||||
|
||||
type AuthAPIOpts struct {
|
||||
GuestAccessAllowed bool
|
||||
}
|
||||
|
||||
// AuthAPIOption is an option to MakeAuthAPI to add additional checks (e.g. guest access) to verify
|
||||
// the user is allowed to do specific things.
|
||||
type AuthAPIOption func(opts *AuthAPIOpts)
|
||||
|
||||
// WithAllowGuests checks that guest users have access to this endpoint
|
||||
func WithAllowGuests() AuthAPIOption {
|
||||
return func(opts *AuthAPIOpts) {
|
||||
opts.GuestAccessAllowed = true
|
||||
}
|
||||
}
|
||||
|
||||
// MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request.
|
||||
func MakeAuthAPI(
|
||||
metricsName string, userAPI userapi.QueryAcccessTokenAPI,
|
||||
f func(*http.Request, *userapi.Device) util.JSONResponse,
|
||||
checks ...AuthAPIOption,
|
||||
) http.Handler {
|
||||
h := func(req *http.Request) util.JSONResponse {
|
||||
logger := util.GetLogger(req.Context())
|
||||
|
|
@ -76,6 +92,19 @@ func MakeAuthAPI(
|
|||
}
|
||||
}()
|
||||
|
||||
// apply additional checks, if any
|
||||
opts := AuthAPIOpts{}
|
||||
for _, opt := range checks {
|
||||
opt(&opts)
|
||||
}
|
||||
|
||||
if !opts.GuestAccessAllowed && device.AccountType == userapi.AccountTypeGuest {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: jsonerror.GuestAccessForbidden("Guest access not allowed"),
|
||||
}
|
||||
}
|
||||
|
||||
jsonRes := f(req, device)
|
||||
// do not log 4xx as errors as they are client fails, not server fails
|
||||
if hub != nil && jsonRes.Code >= 500 {
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ var build string
|
|||
const (
|
||||
VersionMajor = 0
|
||||
VersionMinor = 10
|
||||
VersionPatch = 6
|
||||
VersionPatch = 7
|
||||
VersionTag = "" // example: "rc1"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -49,7 +49,6 @@ var (
|
|||
)
|
||||
)
|
||||
|
||||
const defaultWaitTime = time.Minute
|
||||
const requestTimeout = time.Second * 30
|
||||
|
||||
func init() {
|
||||
|
|
@ -495,7 +494,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go
|
|||
} else if e.Code >= 300 {
|
||||
// We didn't get a real FederationClientError (e.g. in polylith mode, where gomatrix.HTTPError
|
||||
// are "converted" to FederationClientError), but we probably shouldn't hit them every $waitTime seconds.
|
||||
return time.Hour, err
|
||||
return hourWaitTime, err
|
||||
}
|
||||
case net.Error:
|
||||
// Use the default waitTime, if it's a timeout.
|
||||
|
|
@ -509,7 +508,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go
|
|||
// This is to avoid spamming remote servers, which may not be Matrix servers anymore.
|
||||
if e.Code >= 300 {
|
||||
logger.WithError(e).Debug("GetUserDevices returned gomatrix.HTTPError")
|
||||
return time.Hour, err
|
||||
return hourWaitTime, err
|
||||
}
|
||||
default:
|
||||
// Something else failed
|
||||
|
|
|
|||
22
keyserver/internal/device_list_update_default.go
Normal file
22
keyserver/internal/device_list_update_default.go
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build !vw
|
||||
|
||||
package internal
|
||||
|
||||
import "time"
|
||||
|
||||
const defaultWaitTime = time.Minute
|
||||
const hourWaitTime = time.Hour
|
||||
25
keyserver/internal/device_list_update_sytest.go
Normal file
25
keyserver/internal/device_list_update_sytest.go
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build vw
|
||||
|
||||
package internal
|
||||
|
||||
import "time"
|
||||
|
||||
// Sytest is expecting to receive a `/devices` request. The way it is implemented in Dendrite
|
||||
// results in a one-hour wait time from a previous device so the test times out. This is fine for
|
||||
// production, but makes an otherwise passing test fail.
|
||||
const defaultWaitTime = time.Second
|
||||
const hourWaitTime = time.Second
|
||||
|
|
@ -33,16 +33,17 @@ import (
|
|||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/producers"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
type KeyInternalAPI struct {
|
||||
DB storage.Database
|
||||
ThisServer gomatrixserverlib.ServerName
|
||||
FedClient fedsenderapi.KeyserverFederationAPI
|
||||
UserAPI userapi.KeyserverUserAPI
|
||||
Producer *producers.KeyChange
|
||||
Updater *DeviceListUpdater
|
||||
DB storage.Database
|
||||
Cfg *config.KeyServer
|
||||
FedClient fedsenderapi.KeyserverFederationAPI
|
||||
UserAPI userapi.KeyserverUserAPI
|
||||
Producer *producers.KeyChange
|
||||
Updater *DeviceListUpdater
|
||||
}
|
||||
|
||||
func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) {
|
||||
|
|
@ -95,8 +96,11 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
|
|||
nested[userID] = val
|
||||
domainToDeviceKeys[string(serverName)] = nested
|
||||
}
|
||||
// claim local keys
|
||||
if local, ok := domainToDeviceKeys[string(a.ThisServer)]; ok {
|
||||
for domain, local := range domainToDeviceKeys {
|
||||
if !a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
|
||||
continue
|
||||
}
|
||||
// claim local keys
|
||||
keys, err := a.DB.ClaimKeys(ctx, local)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
|
|
@ -117,7 +121,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
|
|||
res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON
|
||||
}
|
||||
}
|
||||
delete(domainToDeviceKeys, string(a.ThisServer))
|
||||
delete(domainToDeviceKeys, domain)
|
||||
}
|
||||
if len(domainToDeviceKeys) > 0 {
|
||||
a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
|
||||
|
|
@ -258,7 +262,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
|||
}
|
||||
domain := string(serverName)
|
||||
// query local devices
|
||||
if serverName == a.ThisServer {
|
||||
if a.Cfg.Matrix.IsLocalServerName(serverName) {
|
||||
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
|
|
@ -437,13 +441,13 @@ func (a *KeyInternalAPI) queryRemoteKeys(
|
|||
|
||||
domains := map[string]struct{}{}
|
||||
for domain := range domainToDeviceKeys {
|
||||
if domain == string(a.ThisServer) {
|
||||
if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
|
||||
continue
|
||||
}
|
||||
domains[domain] = struct{}{}
|
||||
}
|
||||
for domain := range domainToCrossSigningKeys {
|
||||
if domain == string(a.ThisServer) {
|
||||
if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
|
||||
continue
|
||||
}
|
||||
domains[domain] = struct{}{}
|
||||
|
|
@ -689,7 +693,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
|
|||
if err != nil {
|
||||
continue // ignore invalid users
|
||||
}
|
||||
if serverName != a.ThisServer {
|
||||
if !a.Cfg.Matrix.IsLocalServerName(serverName) {
|
||||
continue // ignore remote users
|
||||
}
|
||||
if len(key.KeyJSON) == 0 {
|
||||
|
|
|
|||
|
|
@ -56,10 +56,10 @@ func NewInternalAPI(
|
|||
DB: db,
|
||||
}
|
||||
ap := &internal.KeyInternalAPI{
|
||||
DB: db,
|
||||
ThisServer: cfg.Matrix.ServerName,
|
||||
FedClient: fedClient,
|
||||
Producer: keyChangeProducer,
|
||||
DB: db,
|
||||
Cfg: cfg,
|
||||
FedClient: fedClient,
|
||||
Producer: keyChangeProducer,
|
||||
}
|
||||
updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, rsAPI) // 8 workers TODO: configurable
|
||||
ap.Updater = updater
|
||||
|
|
|
|||
|
|
@ -47,10 +47,10 @@ const upsertStaleDeviceListSQL = "" +
|
|||
" DO UPDATE SET is_stale = $3, ts_added_secs = $4"
|
||||
|
||||
const selectStaleDeviceListsWithDomainsSQL = "" +
|
||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2"
|
||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC"
|
||||
|
||||
const selectStaleDeviceListsSQL = "" +
|
||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
|
||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
|
||||
|
||||
const deleteStaleDevicesSQL = "" +
|
||||
"DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)"
|
||||
|
|
@ -81,7 +81,7 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context,
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix())
|
||||
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now()))
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -46,10 +46,10 @@ const upsertStaleDeviceListSQL = "" +
|
|||
" DO UPDATE SET is_stale = $3, ts_added_secs = $4"
|
||||
|
||||
const selectStaleDeviceListsWithDomainsSQL = "" +
|
||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2"
|
||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC"
|
||||
|
||||
const selectStaleDeviceListsSQL = "" +
|
||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
|
||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
|
||||
|
||||
const deleteStaleDevicesSQL = "" +
|
||||
"DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)"
|
||||
|
|
@ -83,7 +83,7 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context,
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix())
|
||||
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now()))
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -178,6 +178,7 @@ type FederationRoomserverAPI interface {
|
|||
QueryBulkStateContentAPI
|
||||
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
|
||||
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
|
||||
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
|
||||
QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error
|
||||
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error
|
||||
QueryEventsByID(ctx context.Context, req *QueryEventsByIDRequest, res *QueryEventsByIDResponse) error
|
||||
|
|
|
|||
56
roomserver/internal/helpers/helpers_test.go
Normal file
56
roomserver/internal/helpers/helpers_test.go
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
package helpers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||
)
|
||||
|
||||
func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) {
|
||||
base, close := testrig.CreateBaseDendrite(t, dbType)
|
||||
db, err := storage.Open(base, &base.Cfg.RoomServer.Database, base.Caches)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create Database: %v", err)
|
||||
}
|
||||
return base, db, close
|
||||
}
|
||||
|
||||
func TestIsInvitePendingWithoutNID(t *testing.T) {
|
||||
|
||||
alice := test.NewUser(t)
|
||||
bob := test.NewUser(t)
|
||||
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat))
|
||||
_ = bob
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
_, db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
// store all events
|
||||
var authNIDs []types.EventNID
|
||||
for _, x := range room.Events() {
|
||||
|
||||
evNID, _, _, _, _, err := db.StoreEvent(context.Background(), x.Event, authNIDs, false)
|
||||
assert.NoError(t, err)
|
||||
authNIDs = append(authNIDs, evNID)
|
||||
}
|
||||
|
||||
// Alice should have no pending invites and should have a NID
|
||||
pendingInvite, _, _, _, err := IsInvitePending(context.Background(), db, room.ID, alice.ID)
|
||||
assert.NoError(t, err, "failed to get pending invites")
|
||||
assert.False(t, pendingInvite, "unexpected pending invite")
|
||||
|
||||
// Bob should have no pending invites and receive a new NID
|
||||
pendingInvite, _, _, _, err = IsInvitePending(context.Background(), db, room.ID, bob.ID)
|
||||
assert.NoError(t, err, "failed to get pending invites")
|
||||
assert.False(t, pendingInvite, "unexpected pending invite")
|
||||
})
|
||||
}
|
||||
|
|
@ -23,6 +23,8 @@ import (
|
|||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
|
|
@ -409,6 +411,13 @@ func (r *Inputer) processRoomEvent(
|
|||
}
|
||||
}
|
||||
|
||||
// Handle remote room upgrades, e.g. remove published room
|
||||
if event.Type() == "m.room.tombstone" && event.StateKeyEquals("") && !r.Cfg.Matrix.IsLocalServerName(senderDomain) {
|
||||
if err = r.handleRemoteRoomUpgrade(ctx, event); err != nil {
|
||||
return fmt.Errorf("failed to handle remote room upgrade: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// processing this event resulted in an event (which may not be the one we're processing)
|
||||
// being redacted. We are guaranteed to have both sides (the redaction/redacted event),
|
||||
// so notify downstream components to redact this event - they should have it if they've
|
||||
|
|
@ -434,6 +443,13 @@ func (r *Inputer) processRoomEvent(
|
|||
return nil
|
||||
}
|
||||
|
||||
// handleRemoteRoomUpgrade updates published rooms and room aliases
|
||||
func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event *gomatrixserverlib.Event) error {
|
||||
oldRoomID := event.RoomID()
|
||||
newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str
|
||||
return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.Sender())
|
||||
}
|
||||
|
||||
// processStateBefore works out what the state is before the event and
|
||||
// then checks the event auths against the state at the time. It also
|
||||
// tries to determine what the history visibility was of the event at
|
||||
|
|
|
|||
|
|
@ -172,6 +172,6 @@ type Database interface {
|
|||
ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error
|
||||
|
||||
GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error)
|
||||
|
||||
GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error)
|
||||
UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error
|
||||
}
|
||||
|
|
|
|||
|
|
@ -103,6 +103,7 @@ func (d *Database) eventStateKeyNIDs(
|
|||
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
|
||||
) (map[string]types.EventStateKeyNID, error) {
|
||||
result := make(map[string]types.EventStateKeyNID)
|
||||
eventStateKeys = util.UniqueStrings(eventStateKeys)
|
||||
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, eventStateKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -110,6 +111,27 @@ func (d *Database) eventStateKeyNIDs(
|
|||
for eventStateKey, nid := range nids {
|
||||
result[eventStateKey] = nid
|
||||
}
|
||||
// We received some nids, but are still missing some, work out which and create them
|
||||
if len(eventStateKeys) > len(result) {
|
||||
var nid types.EventStateKeyNID
|
||||
err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
|
||||
for _, eventStateKey := range eventStateKeys {
|
||||
if _, ok := result[eventStateKey]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
nid, err = d.assignStateKeyNID(ctx, txn, eventStateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result[eventStateKey] = nid
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
|
|
@ -1243,7 +1265,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
|
|||
|
||||
}
|
||||
|
||||
eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, eventStateKeys)
|
||||
eventStateKeyNIDMap, err := d.eventStateKeyNIDs(ctx, nil, eventStateKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetBulkStateContent: failed to map state key nids: %w", err)
|
||||
}
|
||||
|
|
@ -1309,7 +1331,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userNIDsMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, userIDs)
|
||||
userNIDsMap, err := d.eventStateKeyNIDs(ctx, nil, userIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -1416,6 +1438,36 @@ func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget
|
|||
})
|
||||
}
|
||||
|
||||
func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error {
|
||||
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
// un-publish old room
|
||||
if err := d.PublishedTable.UpsertRoomPublished(ctx, txn, oldRoomID, "", "", false); err != nil {
|
||||
return fmt.Errorf("failed to unpublish room: %w", err)
|
||||
}
|
||||
// publish new room
|
||||
if err := d.PublishedTable.UpsertRoomPublished(ctx, txn, newRoomID, "", "", true); err != nil {
|
||||
return fmt.Errorf("failed to publish room: %w", err)
|
||||
}
|
||||
|
||||
// Migrate any existing room aliases
|
||||
aliases, err := d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, txn, oldRoomID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get room aliases: %w", err)
|
||||
}
|
||||
|
||||
for _, alias := range aliases {
|
||||
if err = d.RoomAliasesTable.DeleteRoomAlias(ctx, txn, alias); err != nil {
|
||||
return fmt.Errorf("failed to remove room alias: %w", err)
|
||||
}
|
||||
if err = d.RoomAliasesTable.InsertRoomAlias(ctx, txn, alias, newRoomID, eventSender); err != nil {
|
||||
return fmt.Errorf("failed to set room alias: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops
|
||||
// it should live in this package!
|
||||
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ func Enable(
|
|||
base *base.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI,
|
||||
fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier, cache caching.SpaceSummaryRoomsCache,
|
||||
) error {
|
||||
clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, cache, base.Cfg.Global.ServerName))
|
||||
clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, cache, base.Cfg.Global.ServerName), httputil.WithAllowGuests())
|
||||
base.PublicClientAPIMux.Handle("/v1/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions)
|
||||
base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ func Setup(
|
|||
// TODO: Add AS support for all handlers below.
|
||||
v3mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return srp.OnIncomingSyncRequest(req, device)
|
||||
})).Methods(http.MethodGet, http.MethodOptions)
|
||||
}, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
|
|
@ -59,7 +59,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], device, rsAPI, cfg, srp, lazyLoadCache)
|
||||
})).Methods(http.MethodGet, http.MethodOptions)
|
||||
}, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomID}/event/{eventID}",
|
||||
httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
|
@ -68,7 +68,7 @@ func Setup(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
return GetEvent(req, device, vars["roomID"], vars["eventID"], cfg, syncDB, rsAPI)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/user/{userId}/filter",
|
||||
|
|
@ -93,7 +93,7 @@ func Setup(
|
|||
|
||||
v3mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return srp.OnIncomingKeyChangeRequest(req, device)
|
||||
})).Methods(http.MethodGet, http.MethodOptions)
|
||||
}, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomId}/context/{eventId}",
|
||||
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
|
|
@ -108,7 +108,7 @@ func Setup(
|
|||
vars["roomId"], vars["eventId"],
|
||||
lazyLoadCache,
|
||||
)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}",
|
||||
|
|
@ -122,7 +122,7 @@ func Setup(
|
|||
req, device, syncDB, rsAPI,
|
||||
vars["roomId"], vars["eventId"], "", "",
|
||||
)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}",
|
||||
|
|
@ -136,7 +136,7 @@ func Setup(
|
|||
req, device, syncDB, rsAPI,
|
||||
vars["roomId"], vars["eventId"], vars["relType"], "",
|
||||
)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}",
|
||||
|
|
@ -150,7 +150,7 @@ func Setup(
|
|||
req, device, syncDB, rsAPI,
|
||||
vars["roomId"], vars["eventId"], vars["relType"], vars["eventType"],
|
||||
)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/search",
|
||||
|
|
@ -191,7 +191,7 @@ func Setup(
|
|||
|
||||
at := req.URL.Query().Get("at")
|
||||
return GetMemberships(req, device, vars["roomID"], syncDB, rsAPI, false, membership, notMembership, at)
|
||||
}),
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
v3mux.Handle("/rooms/{roomID}/joined_members",
|
||||
|
|
|
|||
|
|
@ -264,6 +264,9 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
|||
// Work out what the highest stream position is for all of the events in this
|
||||
// room that were returned.
|
||||
latestPosition := r.To
|
||||
if r.Backwards {
|
||||
latestPosition = r.From
|
||||
}
|
||||
updateLatestPosition := func(mostRecentEventID string) {
|
||||
var pos types.StreamPosition
|
||||
if _, pos, err = snapshot.PositionInTopology(ctx, mostRecentEventID); err == nil {
|
||||
|
|
|
|||
|
|
@ -46,3 +46,6 @@ If a device list update goes missing, the server resyncs on the next one
|
|||
# Might be a bug in the test because leaves do appear :-(
|
||||
|
||||
Leaves are present in non-gapped incremental syncs
|
||||
|
||||
# Below test was passing for the wrong reason, failing correctly since #2858
|
||||
New federated private chats get full presence information (SYN-115)
|
||||
|
|
@ -682,7 +682,6 @@ Presence changes are reported to local room members
|
|||
Presence changes are also reported to remote room members
|
||||
Presence changes to UNAVAILABLE are reported to local room members
|
||||
Presence changes to UNAVAILABLE are reported to remote room members
|
||||
New federated private chats get full presence information (SYN-115)
|
||||
/upgrade copies >100 power levels to the new room
|
||||
Room state after a rejected message event is the same as before
|
||||
Room state after a rejected state event is the same as before
|
||||
|
|
@ -760,3 +759,8 @@ Can filter rooms/{roomId}/members
|
|||
Current state appears in timeline in private history with many messages after
|
||||
AS can publish rooms in their own list
|
||||
AS and main public room lists are separate
|
||||
/upgrade preserves direct room state
|
||||
local user has tags copied to the new room
|
||||
remote user has tags copied to the new room
|
||||
/upgrade moves remote aliases to the new room
|
||||
Local and remote users' homeservers remove a room from their public directory on upgrade
|
||||
|
|
@ -78,7 +78,7 @@ type ClientUserAPI interface {
|
|||
QueryAcccessTokenAPI
|
||||
LoginTokenInternalAPI
|
||||
UserLoginAPI
|
||||
QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error
|
||||
QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error
|
||||
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
||||
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
|
||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||
|
|
@ -335,9 +335,10 @@ type PerformAccountCreationResponse struct {
|
|||
|
||||
// PerformAccountCreationRequest is the request for PerformAccountCreation
|
||||
type PerformPasswordUpdateRequest struct {
|
||||
Localpart string // Required: The localpart for this account.
|
||||
Password string // Required: The new password to set.
|
||||
LogoutDevices bool // Optional: Whether to log out all user devices.
|
||||
Localpart string // Required: The localpart for this account.
|
||||
ServerName gomatrixserverlib.ServerName // Required: The domain for this account.
|
||||
Password string // Required: The new password to set.
|
||||
LogoutDevices bool // Optional: Whether to log out all user devices.
|
||||
}
|
||||
|
||||
// PerformAccountCreationResponse is the response for PerformAccountCreation
|
||||
|
|
@ -518,7 +519,8 @@ const (
|
|||
)
|
||||
|
||||
type QueryPushersRequest struct {
|
||||
Localpart string
|
||||
Localpart string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
type QueryPushersResponse struct {
|
||||
|
|
@ -526,14 +528,16 @@ type QueryPushersResponse struct {
|
|||
}
|
||||
|
||||
type PerformPusherSetRequest struct {
|
||||
Pusher // Anonymous field because that's how clientapi unmarshals it.
|
||||
Localpart string
|
||||
Append bool `json:"append"`
|
||||
Pusher // Anonymous field because that's how clientapi unmarshals it.
|
||||
Localpart string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
Append bool `json:"append"`
|
||||
}
|
||||
|
||||
type PerformPusherDeletionRequest struct {
|
||||
Localpart string
|
||||
SessionID int64
|
||||
Localpart string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
SessionID int64
|
||||
}
|
||||
|
||||
// Pusher represents a push notification subscriber
|
||||
|
|
@ -571,10 +575,11 @@ type QueryPushRulesResponse struct {
|
|||
}
|
||||
|
||||
type QueryNotificationsRequest struct {
|
||||
Localpart string `json:"localpart"` // Required.
|
||||
From string `json:"from,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
Only string `json:"only,omitempty"`
|
||||
Localpart string `json:"localpart"` // Required.
|
||||
ServerName gomatrixserverlib.ServerName `json:"server_name"` // Required.
|
||||
From string `json:"from,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
Only string `json:"only,omitempty"`
|
||||
}
|
||||
|
||||
type QueryNotificationsResponse struct {
|
||||
|
|
@ -601,12 +606,17 @@ type PerformSetAvatarURLResponse struct {
|
|||
Changed bool `json:"changed"`
|
||||
}
|
||||
|
||||
type QueryNumericLocalpartRequest struct {
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
type QueryNumericLocalpartResponse struct {
|
||||
ID int64
|
||||
}
|
||||
|
||||
type QueryAccountAvailabilityRequest struct {
|
||||
Localpart string
|
||||
Localpart string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
type QueryAccountAvailabilityResponse struct {
|
||||
|
|
@ -614,7 +624,9 @@ type QueryAccountAvailabilityResponse struct {
|
|||
}
|
||||
|
||||
type QueryAccountByPasswordRequest struct {
|
||||
Localpart, PlaintextPassword string
|
||||
Localpart string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
PlaintextPassword string
|
||||
}
|
||||
|
||||
type QueryAccountByPasswordResponse struct {
|
||||
|
|
@ -638,11 +650,13 @@ type QueryLocalpartForThreePIDRequest struct {
|
|||
}
|
||||
|
||||
type QueryLocalpartForThreePIDResponse struct {
|
||||
Localpart string
|
||||
Localpart string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
type QueryThreePIDsForLocalpartRequest struct {
|
||||
Localpart string
|
||||
Localpart string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
type QueryThreePIDsForLocalpartResponse struct {
|
||||
|
|
@ -652,5 +666,8 @@ type QueryThreePIDsForLocalpartResponse struct {
|
|||
type PerformForgetThreePIDRequest QueryLocalpartForThreePIDRequest
|
||||
|
||||
type PerformSaveThreePIDAssociationRequest struct {
|
||||
ThreePID, Localpart, Medium string
|
||||
ThreePID string
|
||||
Localpart string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
Medium string
|
||||
}
|
||||
|
|
|
|||
|
|
@ -156,8 +156,8 @@ func (t *UserInternalAPITrace) SetAvatarURL(ctx context.Context, req *PerformSet
|
|||
return err
|
||||
}
|
||||
|
||||
func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error {
|
||||
err := t.Impl.QueryNumericLocalpart(ctx, res)
|
||||
func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error {
|
||||
err := t.Impl.QueryNumericLocalpart(ctx, req, res)
|
||||
util.GetLogger(ctx).Infof("QueryNumericLocalpart req= res=%+v", js(res))
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
|
|||
return false
|
||||
}
|
||||
|
||||
updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true)
|
||||
updated, err := s.db.SetNotificationsRead(ctx, localpart, domain, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("userapi EDU consumer")
|
||||
return false
|
||||
|
|
@ -118,7 +118,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
|
|||
if !updated {
|
||||
return true
|
||||
}
|
||||
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
|
||||
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, domain, s.db); err != nil {
|
||||
log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
|
||||
return false
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,12 +2,16 @@ package consumers
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/nats-io/nats.go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
|
@ -185,13 +189,115 @@ func (s *OutputRoomEventConsumer) storeMessageStats(ctx context.Context, eventTy
|
|||
}
|
||||
}
|
||||
|
||||
func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoomID, newRoomID string, localMembers []*localMembership, roomSize int) error {
|
||||
for _, membership := range localMembers {
|
||||
// Copy any existing push rules from old -> new room
|
||||
if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// preserve m.direct room state
|
||||
if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain, roomSize); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// copy existing m.tag entries, if any
|
||||
if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string, serverName gomatrixserverlib.ServerName) error {
|
||||
pushRules, err := s.db.QueryPushRules(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query pushrules for user: %w", err)
|
||||
}
|
||||
if pushRules == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, roomRule := range pushRules.Global.Room {
|
||||
if roomRule.RuleID != oldRoomID {
|
||||
continue
|
||||
}
|
||||
cpRool := *roomRule
|
||||
cpRool.RuleID = newRoomID
|
||||
pushRules.Global.Room = append(pushRules.Global.Room, &cpRool)
|
||||
rules, err := json.Marshal(pushRules)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.push_rules", rules); err != nil {
|
||||
return fmt.Errorf("failed to update pushrules: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateMDirect copies the "is_direct" flag from oldRoomID to newROomID
|
||||
func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName, roomSize int) error {
|
||||
// this is most likely not a DM, so skip updating m.direct state
|
||||
if roomSize > 2 {
|
||||
return nil
|
||||
}
|
||||
// Get direct message state
|
||||
directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, serverName, "", "m.direct")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get m.direct from database: %w", err)
|
||||
}
|
||||
directChats := gjson.ParseBytes(directChatsRaw)
|
||||
newDirectChats := make(map[string][]string)
|
||||
// iterate over all userID -> roomIDs
|
||||
directChats.ForEach(func(userID, roomIDs gjson.Result) bool {
|
||||
var found bool
|
||||
for _, roomID := range roomIDs.Array() {
|
||||
newDirectChats[userID.Str] = append(newDirectChats[userID.Str], roomID.Str)
|
||||
// add the new roomID to m.direct
|
||||
if roomID.Str == oldRoomID {
|
||||
found = true
|
||||
newDirectChats[userID.Str] = append(newDirectChats[userID.Str], newRoomID)
|
||||
}
|
||||
}
|
||||
// Only hit the database if we found the old room as a DM for this user
|
||||
if found {
|
||||
var data []byte
|
||||
data, err = json.Marshal(newDirectChats)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.direct", data); err != nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update m.direct state")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName) error {
|
||||
tag, err := s.db.GetAccountDataByType(ctx, localpart, serverName, oldRoomID, "m.tag")
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
if tag == nil {
|
||||
return nil
|
||||
}
|
||||
return s.db.SaveAccountData(ctx, localpart, serverName, newRoomID, "m.tag", tag)
|
||||
}
|
||||
|
||||
func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error {
|
||||
members, roomSize, err := s.localRoomMembers(ctx, event.RoomID())
|
||||
if err != nil {
|
||||
return fmt.Errorf("s.localRoomMembers: %w", err)
|
||||
}
|
||||
|
||||
if event.Type() == gomatrixserverlib.MRoomMember {
|
||||
switch {
|
||||
case event.Type() == gomatrixserverlib.MRoomMember:
|
||||
cevent := gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll)
|
||||
var member *localMembership
|
||||
member, err = newLocalMembership(&cevent)
|
||||
|
|
@ -203,6 +309,15 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gom
|
|||
// should also be pushed to the target user.
|
||||
members = append(members, member)
|
||||
}
|
||||
case event.Type() == "m.room.tombstone" && event.StateKeyEquals(""):
|
||||
// Handle room upgrades
|
||||
oldRoomID := event.RoomID()
|
||||
newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str
|
||||
if err = s.handleRoomUpgrade(ctx, oldRoomID, newRoomID, members, roomSize); err != nil {
|
||||
// while inconvenient, this shouldn't stop us from sending push notifications
|
||||
log.WithError(err).Errorf("UserAPI: failed to handle room upgrade for users")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// TODO: run in parallel with localRoomMembers.
|
||||
|
|
@ -377,11 +492,11 @@ func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, er
|
|||
func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error {
|
||||
actions, err := s.evaluatePushRules(ctx, event, mem, roomSize)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("s.evaluatePushRules: %w", err)
|
||||
}
|
||||
a, tweaks, err := pushrules.ActionsToTweaks(actions)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("pushrules.ActionsToTweaks: %w", err)
|
||||
}
|
||||
// TODO: support coalescing.
|
||||
if a != pushrules.NotifyAction && a != pushrules.CoalesceAction {
|
||||
|
|
@ -393,9 +508,9 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
|
|||
return nil
|
||||
}
|
||||
|
||||
devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks)
|
||||
devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, mem.Domain, tweaks)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("s.localPushDevices: %w", err)
|
||||
}
|
||||
|
||||
n := &api.Notification{
|
||||
|
|
@ -412,18 +527,18 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
|
|||
RoomID: event.RoomID(),
|
||||
TS: gomatrixserverlib.AsTimestamp(time.Now()),
|
||||
}
|
||||
if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), streamPos, tweaks, n); err != nil {
|
||||
return err
|
||||
if err = s.db.InsertNotification(ctx, mem.Localpart, mem.Domain, event.EventID(), streamPos, tweaks, n); err != nil {
|
||||
return fmt.Errorf("s.db.InsertNotification: %w", err)
|
||||
}
|
||||
|
||||
if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("s.syncProducer.GetAndSendNotificationData: %w", err)
|
||||
}
|
||||
|
||||
// We do this after InsertNotification. Thus, this should always return >=1.
|
||||
userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications)
|
||||
userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, mem.Domain, tables.AllNotifications)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("s.db.GetNotificationCount: %w", err)
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
|
|
@ -474,7 +589,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
|
|||
}
|
||||
|
||||
if len(rejected) > 0 {
|
||||
s.deleteRejectedPushers(ctx, rejected, mem.Localpart)
|
||||
s.deleteRejectedPushers(ctx, rejected, mem.Localpart, mem.Domain)
|
||||
}
|
||||
}()
|
||||
|
||||
|
|
@ -491,7 +606,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
|
|||
}
|
||||
|
||||
// Get accountdata to check if the event.Sender() is ignored by mem.LocalPart
|
||||
data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, "", "m.ignored_user_list")
|
||||
data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, mem.Domain, "", "m.ignored_user_list")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -506,7 +621,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
|
|||
return nil, fmt.Errorf("user %s is ignored", sender)
|
||||
}
|
||||
}
|
||||
ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart)
|
||||
ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart, mem.Domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -578,10 +693,10 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err
|
|||
|
||||
// localPushDevices pushes to the configured devices of a local
|
||||
// user. The map keys are [url][format].
|
||||
func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
|
||||
pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db)
|
||||
func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
|
||||
pusherDevices, err := util.GetPushDevices(ctx, localpart, serverName, tweaks, s.db)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
return nil, "", fmt.Errorf("util.GetPushDevices: %w", err)
|
||||
}
|
||||
|
||||
var profileTag string
|
||||
|
|
@ -676,7 +791,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatri
|
|||
}
|
||||
|
||||
// deleteRejectedPushers deletes the pushers associated with the given devices.
|
||||
func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) {
|
||||
func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string, serverName gomatrixserverlib.ServerName) {
|
||||
log.WithFields(log.Fields{
|
||||
"localpart": localpart,
|
||||
"app_id0": devices[0].AppID,
|
||||
|
|
@ -684,7 +799,7 @@ func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, dev
|
|||
}).Warnf("Deleting pushers rejected by the HTTP push gateway")
|
||||
|
||||
for _, d := range devices {
|
||||
if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil {
|
||||
if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart, serverName); err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"localpart": localpart,
|
||||
}).WithError(err).Errorf("Unable to delete rejected pusher")
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
|
|||
if req.DataType == "" {
|
||||
return fmt.Errorf("data type must not be empty")
|
||||
}
|
||||
if err := a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData); err != nil {
|
||||
if err := a.DB.SaveAccountData(ctx, local, domain, req.RoomID, req.DataType, req.AccountData); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("a.DB.SaveAccountData failed")
|
||||
return fmt.Errorf("failed to save account data: %w", err)
|
||||
}
|
||||
|
|
@ -108,7 +108,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
|
|||
return nil
|
||||
}
|
||||
|
||||
deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now())))
|
||||
deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, domain, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now())))
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed")
|
||||
return err
|
||||
|
|
@ -124,7 +124,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
|
|||
return nil
|
||||
}
|
||||
|
||||
if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, a.DB); err != nil {
|
||||
if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, domain, a.DB); err != nil {
|
||||
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed")
|
||||
return err
|
||||
}
|
||||
|
|
@ -175,8 +175,10 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
|||
if serverName == "" {
|
||||
serverName = a.Config.Matrix.ServerName
|
||||
}
|
||||
// XXXX: Use the server name here
|
||||
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
|
||||
if !a.Config.Matrix.IsLocalServerName(serverName) {
|
||||
return fmt.Errorf("server name %s is not local", serverName)
|
||||
}
|
||||
acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType)
|
||||
if err != nil {
|
||||
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
||||
switch req.OnConflict {
|
||||
|
|
@ -215,8 +217,8 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
|||
return nil
|
||||
}
|
||||
|
||||
if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
|
||||
return err
|
||||
if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, req.Localpart); err != nil {
|
||||
return fmt.Errorf("a.DB.SetDisplayName: %w", err)
|
||||
}
|
||||
|
||||
postRegisterJoinRooms(a.Cfg, acc, a.RSAPI)
|
||||
|
|
@ -227,11 +229,14 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
|
||||
if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
|
||||
if !a.Config.Matrix.IsLocalServerName(req.ServerName) {
|
||||
return fmt.Errorf("server name %s is not local", req.ServerName)
|
||||
}
|
||||
if err := a.DB.SetPassword(ctx, req.Localpart, req.ServerName, req.Password); err != nil {
|
||||
return err
|
||||
}
|
||||
if req.LogoutDevices {
|
||||
if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, ""); err != nil {
|
||||
if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, req.ServerName, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
@ -244,14 +249,15 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
|
|||
if serverName == "" {
|
||||
serverName = a.Config.Matrix.ServerName
|
||||
}
|
||||
_ = serverName
|
||||
// XXXX: Use the server name here
|
||||
if !a.Config.Matrix.IsLocalServerName(serverName) {
|
||||
return fmt.Errorf("server name %s is not local", serverName)
|
||||
}
|
||||
util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||
"localpart": req.Localpart,
|
||||
"device_id": req.DeviceID,
|
||||
"display_name": req.DeviceDisplayName,
|
||||
}).Info("PerformDeviceCreation")
|
||||
dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
|
||||
dev, err := a.DB.CreateDevice(ctx, req.Localpart, serverName, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -276,12 +282,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
|
|||
deletedDeviceIDs := req.DeviceIDs
|
||||
if len(req.DeviceIDs) == 0 {
|
||||
var devices []api.Device
|
||||
devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
|
||||
devices, err = a.DB.RemoveAllDevices(ctx, local, domain, req.ExceptDeviceID)
|
||||
for _, d := range devices {
|
||||
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
|
||||
}
|
||||
} else {
|
||||
err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs)
|
||||
err = a.DB.RemoveDevices(ctx, local, domain, req.DeviceIDs)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -335,23 +341,29 @@ func (a *UserInternalAPI) PerformLastSeenUpdate(
|
|||
req *api.PerformLastSeenUpdateRequest,
|
||||
res *api.PerformLastSeenUpdateResponse,
|
||||
) error {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
|
||||
}
|
||||
if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil {
|
||||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||
return fmt.Errorf("server name %s is not local", domain)
|
||||
}
|
||||
if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, domain, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil {
|
||||
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||
return err
|
||||
}
|
||||
dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID)
|
||||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||
return fmt.Errorf("server name %s is not local", domain)
|
||||
}
|
||||
dev, err := a.DB.GetDeviceByID(ctx, localpart, domain, req.DeviceID)
|
||||
if err == sql.ErrNoRows {
|
||||
res.DeviceExists = false
|
||||
return nil
|
||||
|
|
@ -366,7 +378,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
|
|||
return nil
|
||||
}
|
||||
|
||||
err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
|
||||
err = a.DB.UpdateDevice(ctx, localpart, domain, req.DeviceID, req.DisplayName)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
|
||||
return err
|
||||
|
|
@ -406,7 +418,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
|
|||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||
return fmt.Errorf("cannot query profile of remote users (server name %s)", domain)
|
||||
}
|
||||
prof, err := a.DB.GetProfileByLocalpart(ctx, local)
|
||||
prof, err := a.DB.GetProfileByLocalpart(ctx, local, domain)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil
|
||||
|
|
@ -457,7 +469,7 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
|
|||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||
return fmt.Errorf("cannot query devices of remote users (server name %s)", domain)
|
||||
}
|
||||
devs, err := a.DB.GetDevicesByLocalpart(ctx, local)
|
||||
devs, err := a.DB.GetDevicesByLocalpart(ctx, local, domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -476,7 +488,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
|
|||
}
|
||||
if req.DataType != "" {
|
||||
var data json.RawMessage
|
||||
data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
|
||||
data, err = a.DB.GetAccountDataByType(ctx, local, domain, req.RoomID, req.DataType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -494,7 +506,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
|
|||
}
|
||||
return nil
|
||||
}
|
||||
global, rooms, err := a.DB.GetAccountData(ctx, local)
|
||||
global, rooms, err := a.DB.GetAccountData(ctx, local, domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -527,7 +539,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
|
|||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||
return nil
|
||||
}
|
||||
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart)
|
||||
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart, domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -561,14 +573,14 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
|
|||
AccountType: api.AccountTypeAppService,
|
||||
}
|
||||
|
||||
localpart, _, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix)
|
||||
localpart, domain, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if localpart != "" { // AS is masquerading as another user
|
||||
// Verify that the user is registered
|
||||
account, err := a.DB.GetAccountByLocalpart(ctx, localpart)
|
||||
account, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain)
|
||||
// Verify that the account exists and either appServiceID matches or
|
||||
// it belongs to the appservice user namespaces
|
||||
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
|
||||
|
|
@ -620,7 +632,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
|
|||
return err
|
||||
}
|
||||
|
||||
err := a.DB.DeactivateAccount(ctx, req.Localpart)
|
||||
err := a.DB.DeactivateAccount(ctx, req.Localpart, serverName)
|
||||
res.AccountDeactivated = err == nil
|
||||
return err
|
||||
}
|
||||
|
|
@ -783,7 +795,7 @@ func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.Query
|
|||
if req.Only == "highlight" {
|
||||
filter = tables.HighlightNotifications
|
||||
}
|
||||
notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, fromID, req.Limit, filter)
|
||||
notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, req.ServerName, fromID, req.Limit, filter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -811,23 +823,23 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform
|
|||
}
|
||||
}
|
||||
if req.Pusher.Kind == "" {
|
||||
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart)
|
||||
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart, req.ServerName)
|
||||
}
|
||||
if req.Pusher.PushKeyTS == 0 {
|
||||
req.Pusher.PushKeyTS = int64(time.Now().Unix())
|
||||
}
|
||||
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart)
|
||||
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart, req.ServerName)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
|
||||
pushers, err := a.DB.GetPushers(ctx, req.Localpart)
|
||||
pushers, err := a.DB.GetPushers(ctx, req.Localpart, req.ServerName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i := range pushers {
|
||||
logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID)
|
||||
if pushers[i].SessionID != req.SessionID {
|
||||
err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart)
|
||||
err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart, req.ServerName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -838,7 +850,7 @@ func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.Pe
|
|||
|
||||
func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
|
||||
var err error
|
||||
res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart)
|
||||
res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart, req.ServerName)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -864,11 +876,11 @@ func (a *UserInternalAPI) PerformPushRulesPut(
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to split user ID %q for push rules", req.UserID)
|
||||
}
|
||||
pushRules, err := a.DB.QueryPushRules(ctx, localpart)
|
||||
pushRules, err := a.DB.QueryPushRules(ctx, localpart, domain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query push rules: %w", err)
|
||||
}
|
||||
|
|
@ -877,14 +889,14 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
|
||||
profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
|
||||
profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.ServerName, req.AvatarURL)
|
||||
res.Profile = profile
|
||||
res.Changed = changed
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
|
||||
id, err := a.DB.GetNewNumericLocalpart(ctx)
|
||||
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, req *api.QueryNumericLocalpartRequest, res *api.QueryNumericLocalpartResponse) error {
|
||||
id, err := a.DB.GetNewNumericLocalpart(ctx, req.ServerName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -894,12 +906,12 @@ func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.Qu
|
|||
|
||||
func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error {
|
||||
var err error
|
||||
res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart)
|
||||
res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart, req.ServerName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error {
|
||||
acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.PlaintextPassword)
|
||||
acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.ServerName, req.PlaintextPassword)
|
||||
switch err {
|
||||
case sql.ErrNoRows: // user does not exist
|
||||
return nil
|
||||
|
|
@ -915,23 +927,24 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error {
|
||||
profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
|
||||
profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.ServerName, req.DisplayName)
|
||||
res.Profile = profile
|
||||
res.Changed = changed
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {
|
||||
localpart, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium)
|
||||
localpart, domain, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.Localpart = localpart
|
||||
res.ServerName = domain
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error {
|
||||
r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart)
|
||||
r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart, req.ServerName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -944,7 +957,7 @@ func (a *UserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.Pe
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error {
|
||||
return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.Medium)
|
||||
return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium)
|
||||
}
|
||||
|
||||
const pushRulesAccountDataType = "m.push_rules"
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
|
|||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||
return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain)
|
||||
}
|
||||
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil {
|
||||
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain); err != nil {
|
||||
res.Data = nil
|
||||
if err == sql.ErrNoRows {
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -355,11 +355,12 @@ func (h *httpUserInternalAPI) SetAvatarURL(
|
|||
|
||||
func (h *httpUserInternalAPI) QueryNumericLocalpart(
|
||||
ctx context.Context,
|
||||
request *api.QueryNumericLocalpartRequest,
|
||||
response *api.QueryNumericLocalpartResponse,
|
||||
) error {
|
||||
return httputil.CallInternalRPCAPI(
|
||||
"QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath,
|
||||
h.httpClient, ctx, &struct{}{}, response,
|
||||
h.httpClient, ctx, request, response,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -15,12 +15,9 @@
|
|||
package inthttp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/matrix-org/dendrite/internal/httputil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
// nolint: gocyclo
|
||||
|
|
@ -152,15 +149,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
|||
httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL),
|
||||
)
|
||||
|
||||
// TODO: Look at the shape of this
|
||||
internalAPIMux.Handle(QueryNumericLocalpartPath,
|
||||
httputil.MakeInternalAPI("UserAPIQueryNumericLocalpart", func(req *http.Request) util.JSONResponse {
|
||||
response := api.QueryNumericLocalpartResponse{}
|
||||
if err := s.QueryNumericLocalpart(req.Context(), &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
internalAPIMux.Handle(
|
||||
QueryNumericLocalpartPath,
|
||||
httputil.MakeInternalRPCAPI("UserAPIQueryNumericLocalpart", s.QueryNumericLocalpart),
|
||||
)
|
||||
|
||||
internalAPIMux.Handle(
|
||||
|
|
|
|||
|
|
@ -61,12 +61,12 @@ func (p *SyncAPI) SendAccountData(userID string, data eventutil.AccountData) err
|
|||
// GetAndSendNotificationData reads the database and sends data about unread
|
||||
// notifications to the Sync API server.
|
||||
func (p *SyncAPI) GetAndSendNotificationData(ctx context.Context, userID, roomID string) error {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, roomID)
|
||||
ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, domain, roomID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,40 +29,40 @@ import (
|
|||
)
|
||||
|
||||
type Profile interface {
|
||||
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
||||
GetProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error)
|
||||
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
||||
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
|
||||
SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, bool, error)
|
||||
SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error)
|
||||
SetDisplayName(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error)
|
||||
}
|
||||
|
||||
type Account interface {
|
||||
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||
// account already exists, it will return nil, ErrUserExists.
|
||||
CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
|
||||
GetNewNumericLocalpart(ctx context.Context) (int64, error)
|
||||
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
|
||||
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
||||
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
|
||||
CreateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||
GetAccountByPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) (*api.Account, error)
|
||||
GetNewNumericLocalpart(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
|
||||
CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error)
|
||||
GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error)
|
||||
DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error)
|
||||
SetPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) error
|
||||
}
|
||||
|
||||
type AccountData interface {
|
||||
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
|
||||
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
|
||||
SaveAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error
|
||||
GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
|
||||
// GetAccountDataByType returns account data matching a given
|
||||
// localpart, room ID and type.
|
||||
// If no account data could be found, returns nil
|
||||
// Returns an error if there was an issue with the retrieval
|
||||
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
|
||||
QueryPushRules(ctx context.Context, localpart string) (*pushrules.AccountRuleSets, error)
|
||||
GetAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error)
|
||||
QueryPushRules(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*pushrules.AccountRuleSets, error)
|
||||
}
|
||||
|
||||
type Device interface {
|
||||
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
|
||||
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
|
||||
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
|
||||
GetDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error)
|
||||
GetDevicesByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Device, error)
|
||||
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
|
||||
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||
// If there is already a device with the same device ID for this user, that access token will be revoked
|
||||
|
|
@ -70,12 +70,12 @@ type Device interface {
|
|||
// an error will be returned.
|
||||
// If no device ID is given one is generated.
|
||||
// Returns the device on success.
|
||||
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
|
||||
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
|
||||
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error
|
||||
RemoveDevices(ctx context.Context, localpart string, devices []string) error
|
||||
CreateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
|
||||
UpdateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error
|
||||
UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error
|
||||
RemoveDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error
|
||||
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
|
||||
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
|
||||
RemoveAllDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) (devices []api.Device, err error)
|
||||
}
|
||||
|
||||
type KeyBackup interface {
|
||||
|
|
@ -107,26 +107,26 @@ type OpenID interface {
|
|||
}
|
||||
|
||||
type Pusher interface {
|
||||
UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error
|
||||
GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error)
|
||||
RemovePusher(ctx context.Context, appid, pushkey, localpart string) error
|
||||
UpsertPusher(ctx context.Context, p api.Pusher, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||
GetPushers(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error)
|
||||
RemovePusher(ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||
RemovePushers(ctx context.Context, appid, pushkey string) error
|
||||
}
|
||||
|
||||
type ThreePID interface {
|
||||
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
|
||||
SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName gomatrixserverlib.ServerName, medium string) (err error)
|
||||
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
|
||||
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
|
||||
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
|
||||
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error)
|
||||
GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error)
|
||||
}
|
||||
|
||||
type Notification interface {
|
||||
InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error
|
||||
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error)
|
||||
SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, read bool) (affected bool, err error)
|
||||
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
|
||||
GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
|
||||
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)
|
||||
InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error
|
||||
DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error)
|
||||
SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, read bool) (affected bool, err error)
|
||||
GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
|
||||
GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error)
|
||||
GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error)
|
||||
DeleteOldNotifications(ctx context.Context) error
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const accountDataSchema = `
|
||||
|
|
@ -29,27 +30,28 @@ const accountDataSchema = `
|
|||
CREATE TABLE IF NOT EXISTS userapi_account_datas (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
-- The room ID for this data (empty string if not specific to a room)
|
||||
room_id TEXT,
|
||||
-- The account data type
|
||||
type TEXT NOT NULL,
|
||||
-- The account data content
|
||||
content TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY(localpart, room_id, type)
|
||||
content TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_account_datas_idx ON userapi_account_datas(localpart, server_name, room_id, type);
|
||||
`
|
||||
|
||||
const insertAccountDataSQL = `
|
||||
INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
|
||||
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content
|
||||
INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = EXCLUDED.content
|
||||
`
|
||||
|
||||
const selectAccountDataSQL = "" +
|
||||
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1"
|
||||
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectAccountDataByTypeSQL = "" +
|
||||
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
||||
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4"
|
||||
|
||||
type accountDataStatements struct {
|
||||
insertAccountDataStmt *sql.Stmt
|
||||
|
|
@ -71,21 +73,24 @@ func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
|
|||
}
|
||||
|
||||
func (s *accountDataStatements) InsertAccountData(
|
||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
roomID, dataType string, content json.RawMessage,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
|
||||
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
|
||||
_, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) SelectAccountData(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (
|
||||
/* global */ map[string]json.RawMessage,
|
||||
/* rooms */ map[string]map[string]json.RawMessage,
|
||||
error,
|
||||
) {
|
||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
|
||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
|
@ -117,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData(
|
|||
}
|
||||
|
||||
func (s *accountDataStatements) SelectAccountDataByType(
|
||||
ctx context.Context, localpart, roomID, dataType string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
roomID, dataType string,
|
||||
) (data json.RawMessage, err error) {
|
||||
var bytes []byte
|
||||
stmt := s.selectAccountDataByTypeStmt
|
||||
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
|
||||
if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ package postgres
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
|
@ -34,7 +35,8 @@ const accountsSchema = `
|
|||
-- Stores data about accounts.
|
||||
CREATE TABLE IF NOT EXISTS userapi_accounts (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
-- When this account was first created, as a unix timestamp (ms resolution).
|
||||
created_ts BIGINT NOT NULL,
|
||||
-- The password hash for this account. Can be NULL if this is a passwordless account.
|
||||
|
|
@ -48,25 +50,27 @@ CREATE TABLE IF NOT EXISTS userapi_accounts (
|
|||
-- TODO:
|
||||
-- upgraded_ts, devices, any email reset stuff?
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_accounts_idx ON userapi_accounts(localpart, server_name);
|
||||
`
|
||||
|
||||
const insertAccountSQL = "" +
|
||||
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
|
||||
"INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||
|
||||
const updatePasswordSQL = "" +
|
||||
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3"
|
||||
|
||||
const deactivateAccountSQL = "" +
|
||||
"UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1"
|
||||
"UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectAccountByLocalpartSQL = "" +
|
||||
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
|
||||
"SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectPasswordHashSQL = "" +
|
||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
|
||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE"
|
||||
|
||||
const selectNewNumericLocalpartSQL = "" +
|
||||
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$'"
|
||||
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$' AND server_name = $1"
|
||||
|
||||
type accountsStatements struct {
|
||||
insertAccountStmt *sql.Stmt
|
||||
|
|
@ -117,59 +121,62 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
|
|||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||
// on success.
|
||||
func (s *accountsStatements) InsertAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
hash, appserviceID string, accountType api.AccountType,
|
||||
) (*api.Account, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
||||
|
||||
var err error
|
||||
if accountType != api.AccountTypeAppService {
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
|
||||
_, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType)
|
||||
} else {
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
|
||||
_, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("insertAccountStmt: %w", err)
|
||||
}
|
||||
|
||||
return &api.Account{
|
||||
Localpart: localpart,
|
||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||
ServerName: s.serverName,
|
||||
UserID: userutil.MakeUserID(localpart, serverName),
|
||||
ServerName: serverName,
|
||||
AppServiceID: appserviceID,
|
||||
AccountType: accountType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *accountsStatements) UpdatePassword(
|
||||
ctx context.Context, localpart, passwordHash string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
passwordHash string,
|
||||
) (err error) {
|
||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
|
||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) DeactivateAccount(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (err error) {
|
||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
|
||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) SelectPasswordHash(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (hash string, err error) {
|
||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
|
||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) SelectAccountByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*api.Account, error) {
|
||||
var appserviceIDPtr sql.NullString
|
||||
var acc api.Account
|
||||
|
||||
stmt := s.selectAccountByLocalpartStmt
|
||||
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
|
||||
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
log.WithError(err).Error("Unable to retrieve user from the db")
|
||||
|
|
@ -180,19 +187,17 @@ func (s *accountsStatements) SelectAccountByLocalpart(
|
|||
acc.AppServiceID = appserviceIDPtr.String
|
||||
}
|
||||
|
||||
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
acc.ServerName = s.serverName
|
||||
|
||||
acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName)
|
||||
return &acc, nil
|
||||
}
|
||||
|
||||
func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) (id int64, err error) {
|
||||
stmt := s.selectNewNumericLocalpartStmt
|
||||
if txn != nil {
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||
err = stmt.QueryRowContext(ctx, serverName).Scan(&id)
|
||||
return id + 1, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,81 @@
|
|||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var serverNamesTables = []string{
|
||||
"userapi_accounts",
|
||||
"userapi_account_datas",
|
||||
"userapi_devices",
|
||||
"userapi_notifications",
|
||||
"userapi_openid_tokens",
|
||||
"userapi_profiles",
|
||||
"userapi_pushers",
|
||||
"userapi_threepids",
|
||||
}
|
||||
|
||||
// These tables have a PRIMARY KEY constraint which we need to drop so
|
||||
// that we can recreate a new unique index that contains the server name.
|
||||
// If the new key doesn't exist (i.e. the database was created before the
|
||||
// table rename migration) we'll try to drop the old one instead.
|
||||
var serverNamesDropPK = map[string]string{
|
||||
"userapi_accounts": "account_accounts",
|
||||
"userapi_account_datas": "account_data",
|
||||
"userapi_profiles": "account_profiles",
|
||||
}
|
||||
|
||||
// These indices are out of date so let's drop them. They will get recreated
|
||||
// automatically.
|
||||
var serverNamesDropIndex = []string{
|
||||
"userapi_pusher_localpart_idx",
|
||||
"userapi_pusher_app_id_pushkey_localpart_idx",
|
||||
}
|
||||
|
||||
// I know what you're thinking: you're wondering "why doesn't this use $1
|
||||
// and pass variadic parameters to ExecContext?" — the answer is because
|
||||
// PostgreSQL doesn't expect the table name to be specified as a substituted
|
||||
// argument in that way so it results in a syntax error in the query.
|
||||
|
||||
func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
|
||||
for _, table := range serverNamesTables {
|
||||
q := fmt.Sprintf(
|
||||
"ALTER TABLE IF EXISTS %s ADD COLUMN IF NOT EXISTS server_name TEXT NOT NULL DEFAULT '';",
|
||||
pq.QuoteIdentifier(table),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("add server name to %q error: %w", table, err)
|
||||
}
|
||||
}
|
||||
for newTable, oldTable := range serverNamesDropPK {
|
||||
q := fmt.Sprintf(
|
||||
"ALTER TABLE IF EXISTS %s DROP CONSTRAINT IF EXISTS %s;",
|
||||
pq.QuoteIdentifier(newTable), pq.QuoteIdentifier(newTable+"_pkey"),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("drop new PK from %q error: %w", newTable, err)
|
||||
}
|
||||
q = fmt.Sprintf(
|
||||
"ALTER TABLE IF EXISTS %s DROP CONSTRAINT IF EXISTS %s;",
|
||||
pq.QuoteIdentifier(newTable), pq.QuoteIdentifier(oldTable+"_pkey"),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("drop old PK from %q error: %w", newTable, err)
|
||||
}
|
||||
}
|
||||
for _, index := range serverNamesDropIndex {
|
||||
q := fmt.Sprintf(
|
||||
"DROP INDEX IF EXISTS %s;",
|
||||
pq.QuoteIdentifier(index),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("drop index %q error: %w", index, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// I know what you're thinking: you're wondering "why doesn't this use $1
|
||||
// and pass variadic parameters to ExecContext?" — the answer is because
|
||||
// PostgreSQL doesn't expect the table name to be specified as a substituted
|
||||
// argument in that way so it results in a syntax error in the query.
|
||||
|
||||
func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
|
||||
for _, table := range serverNamesTables {
|
||||
q := fmt.Sprintf(
|
||||
"UPDATE %s SET server_name = %s WHERE server_name = '';",
|
||||
pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("write server names to %q error: %w", table, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -17,6 +17,7 @@ package postgres
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
|
@ -50,6 +51,7 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
|
|||
-- as it is smaller, makes it clearer that we only manage devices for our own users, and may make
|
||||
-- migration to different domain names easier.
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
-- When this devices was first recognised on the network, as a unix timestamp (ms resolution).
|
||||
created_ts BIGINT NOT NULL,
|
||||
-- The display name, human friendlier than device_id and updatable
|
||||
|
|
@ -65,39 +67,39 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
|
|||
);
|
||||
|
||||
-- Device IDs must be unique for a given user.
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, device_id);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, server_name, device_id);
|
||||
`
|
||||
|
||||
const insertDeviceSQL = "" +
|
||||
"INSERT INTO userapi_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" +
|
||||
"INSERT INTO userapi_devices(device_id, localpart, server_name, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" +
|
||||
" RETURNING session_id"
|
||||
|
||||
const selectDeviceByTokenSQL = "" +
|
||||
"SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
|
||||
"SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1"
|
||||
|
||||
const selectDeviceByIDSQL = "" +
|
||||
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
|
||||
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
|
||||
|
||||
const selectDevicesByLocalpartSQL = "" +
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
|
||||
|
||||
const updateDeviceNameSQL = "" +
|
||||
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4"
|
||||
|
||||
const deleteDeviceSQL = "" +
|
||||
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
|
||||
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3"
|
||||
|
||||
const deleteDevicesByLocalpartSQL = "" +
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3"
|
||||
|
||||
const deleteDevicesSQL = "" +
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id = ANY($2)"
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = ANY($3)"
|
||||
|
||||
const selectDevicesByIDSQL = "" +
|
||||
"SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
|
||||
"SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
|
||||
|
||||
const updateDeviceLastSeen = "" +
|
||||
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
|
||||
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
|
||||
|
||||
type devicesStatements struct {
|
||||
insertDeviceStmt *sql.Stmt
|
||||
|
|
@ -148,18 +150,19 @@ func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName
|
|||
// Returns an error if the user already has a device with the given device ID.
|
||||
// Returns the device on success.
|
||||
func (s *devicesStatements) InsertDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
||||
displayName *string, ipAddr, userAgent string,
|
||||
ctx context.Context, txn *sql.Tx, id string,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
accessToken string, displayName *string, ipAddr, userAgent string,
|
||||
) (*api.Device, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
var sessionID int64
|
||||
stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
|
||||
if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
|
||||
return nil, err
|
||||
if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
|
||||
return nil, fmt.Errorf("insertDeviceStmt: %w", err)
|
||||
}
|
||||
return &api.Device{
|
||||
ID: id,
|
||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||
UserID: userutil.MakeUserID(localpart, serverName),
|
||||
AccessToken: accessToken,
|
||||
SessionID: sessionID,
|
||||
LastSeenTS: createdTimeMS,
|
||||
|
|
@ -170,38 +173,45 @@ func (s *devicesStatements) InsertDevice(
|
|||
|
||||
// deleteDevice removes a single device by id and user localpart.
|
||||
func (s *devicesStatements) DeleteDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart string,
|
||||
ctx context.Context, txn *sql.Tx, id string,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
||||
_, err := stmt.ExecContext(ctx, id, localpart)
|
||||
_, err := stmt.ExecContext(ctx, id, localpart, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
// deleteDevices removes a single or multiple devices by ids and user localpart.
|
||||
// Returns an error if the execution failed.
|
||||
func (s *devicesStatements) DeleteDevices(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
devices []string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt)
|
||||
_, err := stmt.ExecContext(ctx, localpart, pq.Array(devices))
|
||||
_, err := stmt.ExecContext(ctx, localpart, serverName, pq.Array(devices))
|
||||
return err
|
||||
}
|
||||
|
||||
// deleteDevicesByLocalpart removes all devices for the
|
||||
// given user localpart.
|
||||
func (s *devicesStatements) DeleteDevicesByLocalpart(
|
||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
exceptDeviceID string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
||||
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
|
||||
_, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) UpdateDeviceName(
|
||||
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
deviceID string, displayName *string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
||||
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
|
||||
_, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -210,10 +220,11 @@ func (s *devicesStatements) SelectDeviceByToken(
|
|||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
var localpart string
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
stmt := s.selectDeviceByTokenStmt
|
||||
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
|
||||
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName)
|
||||
if err == nil {
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||
dev.AccessToken = accessToken
|
||||
}
|
||||
return &dev, err
|
||||
|
|
@ -222,16 +233,18 @@ func (s *devicesStatements) SelectDeviceByToken(
|
|||
// selectDeviceByID retrieves a device from the database with the given user
|
||||
// localpart and deviceID
|
||||
func (s *devicesStatements) SelectDeviceByID(
|
||||
ctx context.Context, localpart, deviceID string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
deviceID string,
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
var displayName, ip sql.NullString
|
||||
var lastseenTS sql.NullInt64
|
||||
stmt := s.selectDeviceByIDStmt
|
||||
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip)
|
||||
err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip)
|
||||
if err == nil {
|
||||
dev.ID = deviceID
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||
if displayName.Valid {
|
||||
dev.DisplayName = displayName.String
|
||||
}
|
||||
|
|
@ -254,10 +267,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
|||
var devices []api.Device
|
||||
var dev api.Device
|
||||
var localpart string
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
var lastseents sql.NullInt64
|
||||
var displayName sql.NullString
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
|
||||
if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if displayName.Valid {
|
||||
|
|
@ -266,17 +280,19 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
|||
if lastseents.Valid {
|
||||
dev.LastSeenTS = lastseents.Int64
|
||||
}
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||
devices = append(devices, dev)
|
||||
}
|
||||
return devices, rows.Err()
|
||||
}
|
||||
|
||||
func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
exceptDeviceID string,
|
||||
) ([]api.Device, error) {
|
||||
devices := []api.Device{}
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID)
|
||||
|
||||
if err != nil {
|
||||
return devices, err
|
||||
|
|
@ -307,16 +323,16 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
|
|||
dev.UserAgent = useragent.String
|
||||
}
|
||||
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||
devices = append(devices, dev)
|
||||
}
|
||||
|
||||
return devices, rows.Err()
|
||||
}
|
||||
|
||||
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error {
|
||||
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
|
||||
lastSeenTs := time.Now().UnixNano() / 1000000
|
||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
||||
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID)
|
||||
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID)
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ const notificationSchema = `
|
|||
CREATE TABLE IF NOT EXISTS userapi_notifications (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
stream_pos BIGINT NOT NULL,
|
||||
|
|
@ -52,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications (
|
|||
read BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id);
|
||||
`
|
||||
|
||||
const insertNotificationSQL = "" +
|
||||
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
||||
"INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
|
||||
|
||||
const deleteNotificationsUpToSQL = "" +
|
||||
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
|
||||
"DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4"
|
||||
|
||||
const updateNotificationReadSQL = "" +
|
||||
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
|
||||
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1"
|
||||
|
||||
const selectNotificationSQL = "" +
|
||||
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
|
||||
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
|
||||
") AND NOT read ORDER BY localpart, id LIMIT $4"
|
||||
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" +
|
||||
"(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" +
|
||||
") AND NOT read ORDER BY localpart, id LIMIT $5"
|
||||
|
||||
const selectNotificationCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
|
||||
"(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
|
||||
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" +
|
||||
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
|
||||
") AND NOT read"
|
||||
|
||||
const selectRoomNotificationCountsSQL = "" +
|
||||
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
|
||||
"WHERE localpart = $1 AND room_id = $2 AND NOT read"
|
||||
"WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read"
|
||||
|
||||
const cleanNotificationsSQL = "" +
|
||||
"DELETE FROM userapi_notifications WHERE" +
|
||||
|
|
@ -111,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
|
|||
}
|
||||
|
||||
// Insert inserts a notification into the database.
|
||||
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
|
||||
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error {
|
||||
roomID, tsMS := n.RoomID, n.TS
|
||||
nn := *n
|
||||
// Clears out fields that have their own columns to (1) shrink the
|
||||
|
|
@ -122,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
|
||||
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs))
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteUpTo deletes all previous notifications, up to and including the event.
|
||||
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
|
||||
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
|
@ -141,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
|
|||
}
|
||||
|
||||
// UpdateRead updates the "read" value for an event.
|
||||
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
|
||||
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
|
@ -154,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l
|
|||
return nrows > 0, nil
|
||||
}
|
||||
|
||||
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
|
||||
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit)
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
|
|
@ -197,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
|
|||
return notifs, maxID, rows.Err()
|
||||
}
|
||||
|
||||
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) {
|
||||
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count)
|
||||
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) {
|
||||
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) {
|
||||
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight)
|
||||
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) {
|
||||
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight)
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package postgres
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
|
|
@ -18,16 +19,17 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
|
|||
token TEXT NOT NULL PRIMARY KEY,
|
||||
-- The Matrix user ID for this account
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
-- When the token expires, as a unix timestamp (ms resolution).
|
||||
token_expires_at_ms BIGINT NOT NULL
|
||||
);
|
||||
`
|
||||
|
||||
const insertOpenIDTokenSQL = "" +
|
||||
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const selectOpenIDTokenSQL = "" +
|
||||
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
||||
"SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
||||
|
||||
type openIDTokenStatements struct {
|
||||
insertTokenStmt *sql.Stmt
|
||||
|
|
@ -54,11 +56,11 @@ func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
|||
func (s *openIDTokenStatements) InsertOpenIDToken(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
token, localpart string,
|
||||
token, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
expiresAtMS int64,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
||||
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
|
||||
_, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -69,10 +71,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
|
|||
token string,
|
||||
) (*api.OpenIDTokenAttributes, error) {
|
||||
var openIDTokenAttrs api.OpenIDTokenAttributes
|
||||
var localpart string
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
||||
&openIDTokenAttrs.UserID,
|
||||
&localpart, &serverName,
|
||||
&openIDTokenAttrs.ExpiresAtMS,
|
||||
)
|
||||
openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
log.WithError(err).Error("Unable to retrieve token from the db")
|
||||
|
|
|
|||
|
|
@ -23,42 +23,46 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const profilesSchema = `
|
||||
-- Stores data about accounts profiles.
|
||||
CREATE TABLE IF NOT EXISTS userapi_profiles (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
-- The display name for this account
|
||||
display_name TEXT,
|
||||
-- The URL of the avatar for this account
|
||||
avatar_url TEXT
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_profiles_idx ON userapi_profiles(localpart, server_name);
|
||||
`
|
||||
|
||||
const insertProfileSQL = "" +
|
||||
"INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const selectProfileByLocalpartSQL = "" +
|
||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
|
||||
"SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const setAvatarURLSQL = "" +
|
||||
"UPDATE userapi_profiles AS new" +
|
||||
" SET avatar_url = $1" +
|
||||
" FROM userapi_profiles AS old" +
|
||||
" WHERE new.localpart = $2" +
|
||||
" WHERE new.localpart = $2 AND new.server_name = $3" +
|
||||
" RETURNING new.display_name, old.avatar_url <> new.avatar_url"
|
||||
|
||||
const setDisplayNameSQL = "" +
|
||||
"UPDATE userapi_profiles AS new" +
|
||||
" SET display_name = $1" +
|
||||
" FROM userapi_profiles AS old" +
|
||||
" WHERE new.localpart = $2" +
|
||||
" WHERE new.localpart = $2 AND new.server_name = $3" +
|
||||
" RETURNING new.avatar_url, old.display_name <> new.display_name"
|
||||
|
||||
const selectProfilesBySearchSQL = "" +
|
||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||
"SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||
|
||||
type profilesStatements struct {
|
||||
serverNoticesLocalpart string
|
||||
|
|
@ -87,18 +91,20 @@ func NewPostgresProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables
|
|||
}
|
||||
|
||||
func (s *profilesStatements) InsertProfile(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
||||
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "")
|
||||
return
|
||||
}
|
||||
|
||||
func (s *profilesStatements) SelectProfileByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*authtypes.Profile, error) {
|
||||
var profile authtypes.Profile
|
||||
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
|
||||
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
|
||||
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan(
|
||||
&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -107,28 +113,34 @@ func (s *profilesStatements) SelectProfileByLocalpart(
|
|||
}
|
||||
|
||||
func (s *profilesStatements) SetAvatarURL(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
avatarURL string,
|
||||
) (*authtypes.Profile, bool, error) {
|
||||
profile := &authtypes.Profile{
|
||||
Localpart: localpart,
|
||||
AvatarURL: avatarURL,
|
||||
Localpart: localpart,
|
||||
ServerName: string(serverName),
|
||||
AvatarURL: avatarURL,
|
||||
}
|
||||
var changed bool
|
||||
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
||||
err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName, &changed)
|
||||
err := stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName, &changed)
|
||||
return profile, changed, err
|
||||
}
|
||||
|
||||
func (s *profilesStatements) SetDisplayName(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
displayName string,
|
||||
) (*authtypes.Profile, bool, error) {
|
||||
profile := &authtypes.Profile{
|
||||
Localpart: localpart,
|
||||
ServerName: string(serverName),
|
||||
DisplayName: displayName,
|
||||
}
|
||||
var changed bool
|
||||
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
||||
err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL, &changed)
|
||||
err := stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL, &changed)
|
||||
return profile, changed, err
|
||||
}
|
||||
|
||||
|
|
@ -146,7 +158,7 @@ func (s *profilesStatements) SelectProfilesBySearch(
|
|||
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var profile authtypes.Profile
|
||||
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
||||
if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if profile.Localpart != s.serverNoticesLocalpart {
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
|
||||
|
|
@ -33,6 +34,7 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
|
|||
id BIGSERIAL PRIMARY KEY,
|
||||
-- The Matrix user ID localpart for this pusher
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
session_id BIGINT DEFAULT NULL,
|
||||
profile_tag TEXT,
|
||||
kind TEXT NOT NULL,
|
||||
|
|
@ -49,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
|
|||
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
|
||||
|
||||
-- For faster retrieving by localpart.
|
||||
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
|
||||
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name);
|
||||
|
||||
-- Pushkey must be unique for a given user and app.
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name);
|
||||
`
|
||||
|
||||
const insertPusherSQL = "" +
|
||||
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
|
||||
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
|
||||
"INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" +
|
||||
"ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12"
|
||||
|
||||
const selectPushersSQL = "" +
|
||||
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
|
||||
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const deletePusherSQL = "" +
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4"
|
||||
|
||||
const deletePushersByAppIdAndPushKeySQL = "" +
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
|
||||
|
|
@ -95,18 +97,19 @@ type pushersStatements struct {
|
|||
// Returns nil error success.
|
||||
func (s *pushersStatements) InsertPusher(
|
||||
ctx context.Context, txn *sql.Tx, session_id int64,
|
||||
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
|
||||
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
||||
logrus.Debugf("Created pusher %d", session_id)
|
||||
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *pushersStatements) SelectPushers(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) ([]api.Pusher, error) {
|
||||
pushers := []api.Pusher{}
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart)
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart, serverName)
|
||||
|
||||
if err != nil {
|
||||
return pushers, err
|
||||
|
|
@ -143,9 +146,10 @@ func (s *pushersStatements) SelectPushers(
|
|||
|
||||
// deletePusher removes a single pusher by pushkey and user localpart.
|
||||
func (s *pushersStatements) DeletePusher(
|
||||
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
|
||||
ctx context.Context, txn *sql.Tx, appid, pushkey,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
|
||||
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
|
|
@ -43,18 +45,24 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
Up: deltas.UpRenameTables,
|
||||
Down: deltas.DownRenameTables,
|
||||
})
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "userapi: server names",
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
return deltas.UpServerNames(ctx, txn, serverName)
|
||||
},
|
||||
})
|
||||
if err = m.Up(base.Context()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountDataTable, err := NewPostgresAccountDataTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
|
||||
}
|
||||
accountsTable, err := NewPostgresAccountsTable(db, serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err)
|
||||
}
|
||||
accountDataTable, err := NewPostgresAccountDataTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
|
||||
}
|
||||
devicesTable, err := NewPostgresDevicesTable(db, serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err)
|
||||
|
|
@ -95,6 +103,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresStatsTable: %w", err)
|
||||
}
|
||||
|
||||
m = sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "userapi: server names populate",
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
return deltas.UpServerNamesPopulate(ctx, txn, serverName)
|
||||
},
|
||||
})
|
||||
if err = m.Up(base.Context()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &shared.Database{
|
||||
AccountDatas: accountDataTable,
|
||||
Accounts: accountsTable,
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
)
|
||||
|
|
@ -33,21 +34,22 @@ CREATE TABLE IF NOT EXISTS userapi_threepids (
|
|||
medium TEXT NOT NULL DEFAULT 'email',
|
||||
-- The localpart of the Matrix user ID associated to this 3PID
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY(threepid, medium)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart);
|
||||
CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart, server_name);
|
||||
`
|
||||
|
||||
const selectLocalpartForThreePIDSQL = "" +
|
||||
"SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||
"SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||
|
||||
const selectThreePIDsForLocalpartSQL = "" +
|
||||
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1"
|
||||
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const insertThreePIDSQL = "" +
|
||||
"INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const deleteThreePIDSQL = "" +
|
||||
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||
|
|
@ -75,19 +77,20 @@ func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
|
|||
|
||||
func (s *threepidStatements) SelectLocalpartForThreePID(
|
||||
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
||||
) (localpart string, err error) {
|
||||
) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
||||
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
|
||||
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", nil
|
||||
return "", "", nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (threepids []authtypes.ThreePID, err error) {
|
||||
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
|
||||
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -109,10 +112,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
|||
}
|
||||
|
||||
func (s *threepidStatements) InsertThreePID(
|
||||
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
|
||||
ctx context.Context, txn *sql.Tx, threepid, medium,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
||||
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
|
||||
_, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -68,9 +68,10 @@ const (
|
|||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||
func (d *Database) GetAccountByPassword(
|
||||
ctx context.Context, localpart, plaintextPassword string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
plaintextPassword string,
|
||||
) (*api.Account, error) {
|
||||
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart)
|
||||
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -80,24 +81,27 @@ func (d *Database) GetAccountByPassword(
|
|||
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.Accounts.SelectAccountByLocalpart(ctx, localpart)
|
||||
return d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
|
||||
}
|
||||
|
||||
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
||||
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
||||
func (d *Database) GetProfileByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*authtypes.Profile, error) {
|
||||
return d.Profiles.SelectProfileByLocalpart(ctx, localpart)
|
||||
return d.Profiles.SelectProfileByLocalpart(ctx, localpart, serverName)
|
||||
}
|
||||
|
||||
// SetAvatarURL updates the avatar URL of the profile associated with the given
|
||||
// localpart. Returns an error if something went wrong with the SQL query
|
||||
func (d *Database) SetAvatarURL(
|
||||
ctx context.Context, localpart string, avatarURL string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
avatarURL string,
|
||||
) (profile *authtypes.Profile, changed bool, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
|
||||
profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, serverName, avatarURL)
|
||||
return err
|
||||
})
|
||||
return
|
||||
|
|
@ -106,10 +110,12 @@ func (d *Database) SetAvatarURL(
|
|||
// SetDisplayName updates the display name of the profile associated with the given
|
||||
// localpart. Returns an error if something went wrong with the SQL query
|
||||
func (d *Database) SetDisplayName(
|
||||
ctx context.Context, localpart string, displayName string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
displayName string,
|
||||
) (profile *authtypes.Profile, changed bool, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
|
||||
profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, serverName, displayName)
|
||||
return err
|
||||
})
|
||||
return
|
||||
|
|
@ -117,14 +123,15 @@ func (d *Database) SetDisplayName(
|
|||
|
||||
// SetPassword sets the account password to the given hash.
|
||||
func (d *Database) SetPassword(
|
||||
ctx context.Context, localpart, plaintextPassword string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
plaintextPassword string,
|
||||
) error {
|
||||
hash, err := d.hashPassword(plaintextPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
|
||||
return d.Accounts.UpdatePassword(ctx, localpart, hash)
|
||||
return d.Accounts.UpdatePassword(ctx, localpart, serverName, hash)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -132,21 +139,22 @@ func (d *Database) SetPassword(
|
|||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||
// account already exists, it will return nil, ErrUserExists.
|
||||
func (d *Database) CreateAccount(
|
||||
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||
) (acc *api.Account, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
// For guest accounts, we create a new numeric local part
|
||||
if accountType == api.AccountTypeGuest {
|
||||
var numLocalpart int64
|
||||
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn)
|
||||
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn, serverName)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("d.Accounts.SelectNewNumericLocalpart: %w", err)
|
||||
}
|
||||
localpart = strconv.FormatInt(numLocalpart, 10)
|
||||
plaintextPassword = ""
|
||||
appserviceID = ""
|
||||
}
|
||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType)
|
||||
acc, err = d.createAccount(ctx, txn, localpart, serverName, plaintextPassword, appserviceID, accountType)
|
||||
return err
|
||||
})
|
||||
return
|
||||
|
|
@ -155,7 +163,9 @@ func (d *Database) CreateAccount(
|
|||
// WARNING! This function assumes that the relevant mutexes have already
|
||||
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
|
||||
func (d *Database) createAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||
) (*api.Account, error) {
|
||||
var err error
|
||||
var account *api.Account
|
||||
|
|
@ -167,28 +177,28 @@ func (d *Database) createAccount(
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
|
||||
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, serverName, hash, appserviceID, accountType); err != nil {
|
||||
return nil, sqlutil.ErrUserExists
|
||||
}
|
||||
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
|
||||
return nil, err
|
||||
if err = d.Profiles.InsertProfile(ctx, txn, localpart, serverName); err != nil {
|
||||
return nil, fmt.Errorf("d.Profiles.InsertProfile: %w", err)
|
||||
}
|
||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
|
||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
|
||||
prbs, err := json.Marshal(pushRuleSets)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("json.Marshal: %w", err)
|
||||
}
|
||||
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
|
||||
return nil, err
|
||||
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
|
||||
return nil, fmt.Errorf("d.AccountDatas.InsertAccountData: %w", err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (d *Database) QueryPushRules(
|
||||
ctx context.Context,
|
||||
localpart string,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*pushrules.AccountRuleSets, error) {
|
||||
data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, "", "m.push_rules")
|
||||
data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, serverName, "", "m.push_rules")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -196,13 +206,13 @@ func (d *Database) QueryPushRules(
|
|||
// If we didn't find any default push rules then we should just generate some
|
||||
// fresh ones.
|
||||
if len(data) == 0 {
|
||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
|
||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
|
||||
prbs, err := json.Marshal(pushRuleSets)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal default push rules: %w", err)
|
||||
}
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", prbs); dbErr != nil {
|
||||
if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", prbs); dbErr != nil {
|
||||
return fmt.Errorf("failed to save default push rules: %w", dbErr)
|
||||
}
|
||||
return nil
|
||||
|
|
@ -225,22 +235,23 @@ func (d *Database) QueryPushRules(
|
|||
// update the corresponding row with the new content
|
||||
// Returns a SQL error if there was an issue with the insertion/update
|
||||
func (d *Database) SaveAccountData(
|
||||
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
roomID, dataType string, content json.RawMessage,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.AccountDatas.InsertAccountData(ctx, txn, localpart, roomID, dataType, content)
|
||||
return d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, roomID, dataType, content)
|
||||
})
|
||||
}
|
||||
|
||||
// GetAccountData returns account data related to a given localpart
|
||||
// If no account data could be found, returns an empty arrays
|
||||
// Returns an error if there was an issue with the retrieval
|
||||
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
||||
func (d *Database) GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (
|
||||
global map[string]json.RawMessage,
|
||||
rooms map[string]map[string]json.RawMessage,
|
||||
err error,
|
||||
) {
|
||||
return d.AccountDatas.SelectAccountData(ctx, localpart)
|
||||
return d.AccountDatas.SelectAccountData(ctx, localpart, serverName)
|
||||
}
|
||||
|
||||
// GetAccountDataByType returns account data matching a given
|
||||
|
|
@ -248,18 +259,19 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
|||
// If no account data could be found, returns nil
|
||||
// Returns an error if there was an issue with the retrieval
|
||||
func (d *Database) GetAccountDataByType(
|
||||
ctx context.Context, localpart, roomID, dataType string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
roomID, dataType string,
|
||||
) (data json.RawMessage, err error) {
|
||||
return d.AccountDatas.SelectAccountDataByType(
|
||||
ctx, localpart, roomID, dataType,
|
||||
ctx, localpart, serverName, roomID, dataType,
|
||||
)
|
||||
}
|
||||
|
||||
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
|
||||
func (d *Database) GetNewNumericLocalpart(
|
||||
ctx context.Context,
|
||||
ctx context.Context, serverName gomatrixserverlib.ServerName,
|
||||
) (int64, error) {
|
||||
return d.Accounts.SelectNewNumericLocalpart(ctx, nil)
|
||||
return d.Accounts.SelectNewNumericLocalpart(ctx, nil, serverName)
|
||||
}
|
||||
|
||||
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
|
||||
|
|
@ -276,10 +288,12 @@ var Err3PIDInUse = errors.New("this third-party identifier is already in use")
|
|||
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
func (d *Database) SaveThreePIDAssociation(
|
||||
ctx context.Context, threepid, localpart, medium string,
|
||||
ctx context.Context, threepid string,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
medium string,
|
||||
) (err error) {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
user, err := d.ThreePIDs.SelectLocalpartForThreePID(
|
||||
user, _, err := d.ThreePIDs.SelectLocalpartForThreePID(
|
||||
ctx, txn, threepid, medium,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
@ -290,7 +304,7 @@ func (d *Database) SaveThreePIDAssociation(
|
|||
return Err3PIDInUse
|
||||
}
|
||||
|
||||
return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart)
|
||||
return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart, serverName)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -313,7 +327,7 @@ func (d *Database) RemoveThreePIDAssociation(
|
|||
// Returns an error if there was a problem talking to the database.
|
||||
func (d *Database) GetLocalpartForThreePID(
|
||||
ctx context.Context, threepid string, medium string,
|
||||
) (localpart string, err error) {
|
||||
) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
|
||||
return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium)
|
||||
}
|
||||
|
||||
|
|
@ -322,16 +336,17 @@ func (d *Database) GetLocalpartForThreePID(
|
|||
// If no association is known for this user, returns an empty slice.
|
||||
// Returns an error if there was an issue talking to the database.
|
||||
func (d *Database) GetThreePIDsForLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (threepids []authtypes.ThreePID, err error) {
|
||||
return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart)
|
||||
return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName)
|
||||
}
|
||||
|
||||
// CheckAccountAvailability checks if the username/localpart is already present
|
||||
// in the database.
|
||||
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
|
||||
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
|
||||
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart)
|
||||
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) {
|
||||
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
|
||||
if err == sql.ErrNoRows {
|
||||
return true, nil
|
||||
}
|
||||
|
|
@ -341,12 +356,12 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin
|
|||
// GetAccountByLocalpart returns the account associated with the given localpart.
|
||||
// This function assumes the request is authenticated or the account data is used only internally.
|
||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
|
||||
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*api.Account, error) {
|
||||
// try to get the account with lowercase localpart (majority)
|
||||
acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart))
|
||||
acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart), serverName)
|
||||
if err == sql.ErrNoRows {
|
||||
acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart) // try with localpart as passed by the request
|
||||
acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) // try with localpart as passed by the request
|
||||
}
|
||||
return acc, err
|
||||
}
|
||||
|
|
@ -359,20 +374,24 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi
|
|||
}
|
||||
|
||||
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
|
||||
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
|
||||
func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) {
|
||||
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
|
||||
return d.Accounts.DeactivateAccount(ctx, localpart)
|
||||
return d.Accounts.DeactivateAccount(ctx, localpart, serverName)
|
||||
})
|
||||
}
|
||||
|
||||
// CreateOpenIDToken persists a new token that was issued for OpenID Connect
|
||||
func (d *Database) CreateOpenIDToken(
|
||||
ctx context.Context,
|
||||
token, localpart string,
|
||||
token, userID string,
|
||||
) (int64, error) {
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
return 0, nil
|
||||
}
|
||||
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS
|
||||
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS)
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, domain, expiresAtMS)
|
||||
})
|
||||
return expiresAtMS, err
|
||||
}
|
||||
|
|
@ -539,16 +558,19 @@ func (d *Database) GetDeviceByAccessToken(
|
|||
// GetDeviceByID returns the device matching the given ID.
|
||||
// Returns sql.ErrNoRows if no matching device was found.
|
||||
func (d *Database) GetDeviceByID(
|
||||
ctx context.Context, localpart, deviceID string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
deviceID string,
|
||||
) (*api.Device, error) {
|
||||
return d.Devices.SelectDeviceByID(ctx, localpart, deviceID)
|
||||
return d.Devices.SelectDeviceByID(ctx, localpart, serverName, deviceID)
|
||||
}
|
||||
|
||||
// GetDevicesByLocalpart returns the devices matching the given localpart.
|
||||
func (d *Database) GetDevicesByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) ([]api.Device, error) {
|
||||
return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, "")
|
||||
return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, serverName, "")
|
||||
}
|
||||
|
||||
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||
|
|
@ -562,18 +584,18 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap
|
|||
// If no device ID is given one is generated.
|
||||
// Returns the device on success.
|
||||
func (d *Database) CreateDevice(
|
||||
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
||||
displayName *string, ipAddr, userAgent string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string,
|
||||
) (dev *api.Device, returnErr error) {
|
||||
if deviceID != nil {
|
||||
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
// Revoke existing tokens for this device
|
||||
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
||||
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
|
||||
return err
|
||||
})
|
||||
} else {
|
||||
|
|
@ -588,7 +610,7 @@ func (d *Database) CreateDevice(
|
|||
|
||||
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||
dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
|
||||
return err
|
||||
})
|
||||
if returnErr == nil {
|
||||
|
|
@ -614,10 +636,12 @@ func generateDeviceID() (string, error) {
|
|||
// UpdateDevice updates the given device with the display name.
|
||||
// Returns SQL error if there are problems and nil on success.
|
||||
func (d *Database) UpdateDevice(
|
||||
ctx context.Context, localpart, deviceID string, displayName *string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
deviceID string, displayName *string,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.Devices.UpdateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
||||
return d.Devices.UpdateDeviceName(ctx, txn, localpart, serverName, deviceID, displayName)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -626,10 +650,12 @@ func (d *Database) UpdateDevice(
|
|||
// If the devices don't exist, it will not return an error
|
||||
// If something went wrong during the deletion, it will return the SQL error.
|
||||
func (d *Database) RemoveDevices(
|
||||
ctx context.Context, localpart string, devices []string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
devices []string,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
if err := d.Devices.DeleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
|
||||
if err := d.Devices.DeleteDevices(ctx, txn, localpart, serverName, devices); err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
|
@ -640,14 +666,16 @@ func (d *Database) RemoveDevices(
|
|||
// database matching the given user ID localpart.
|
||||
// If something went wrong during the deletion, it will return the SQL error.
|
||||
func (d *Database) RemoveAllDevices(
|
||||
ctx context.Context, localpart, exceptDeviceID string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
exceptDeviceID string,
|
||||
) (devices []api.Device, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
|
||||
devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
|
||||
if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID); err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
|
@ -656,9 +684,9 @@ func (d *Database) RemoveAllDevices(
|
|||
}
|
||||
|
||||
// UpdateDeviceLastSeen updates a last seen timestamp and the ip address.
|
||||
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error {
|
||||
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr, userAgent)
|
||||
return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, serverName, deviceID, ipAddr, userAgent)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -706,38 +734,38 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (
|
|||
return d.LoginTokens.SelectLoginToken(ctx, token)
|
||||
}
|
||||
|
||||
func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error {
|
||||
func (d *Database) InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
|
||||
return d.Notifications.Insert(ctx, txn, localpart, serverName, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) {
|
||||
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos)
|
||||
affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, serverName, roomID, pos)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, b bool) (affected bool, err error) {
|
||||
func (d *Database) SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, b bool) (affected bool, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b)
|
||||
affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, serverName, roomID, pos, b)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||
return d.Notifications.Select(ctx, nil, localpart, fromID, limit, filter)
|
||||
func (d *Database) GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||
return d.Notifications.Select(ctx, nil, localpart, serverName, fromID, limit, filter)
|
||||
}
|
||||
|
||||
func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) {
|
||||
return d.Notifications.SelectCount(ctx, nil, localpart, filter)
|
||||
func (d *Database) GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error) {
|
||||
return d.Notifications.SelectCount(ctx, nil, localpart, serverName, filter)
|
||||
}
|
||||
|
||||
func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) {
|
||||
return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID)
|
||||
func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) {
|
||||
return d.Notifications.SelectRoomCounts(ctx, nil, localpart, serverName, roomID)
|
||||
}
|
||||
|
||||
func (d *Database) DeleteOldNotifications(ctx context.Context) error {
|
||||
|
|
@ -747,7 +775,8 @@ func (d *Database) DeleteOldNotifications(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (d *Database) UpsertPusher(
|
||||
ctx context.Context, p api.Pusher, localpart string,
|
||||
ctx context.Context, p api.Pusher,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
data, err := json.Marshal(p.Data)
|
||||
if err != nil {
|
||||
|
|
@ -766,25 +795,26 @@ func (d *Database) UpsertPusher(
|
|||
p.ProfileTag,
|
||||
p.Language,
|
||||
string(data),
|
||||
localpart)
|
||||
localpart,
|
||||
serverName)
|
||||
})
|
||||
}
|
||||
|
||||
// GetPushers returns the pushers matching the given localpart.
|
||||
func (d *Database) GetPushers(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) ([]api.Pusher, error) {
|
||||
return d.Pushers.SelectPushers(ctx, nil, localpart)
|
||||
return d.Pushers.SelectPushers(ctx, nil, localpart, serverName)
|
||||
}
|
||||
|
||||
// RemovePusher deletes one pusher
|
||||
// Invoked when `append` is true and `kind` is null in
|
||||
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
|
||||
func (d *Database) RemovePusher(
|
||||
ctx context.Context, appid, pushkey, localpart string,
|
||||
ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart)
|
||||
err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart, serverName)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const accountDataSchema = `
|
||||
|
|
@ -28,27 +29,28 @@ const accountDataSchema = `
|
|||
CREATE TABLE IF NOT EXISTS userapi_account_datas (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
-- The room ID for this data (empty string if not specific to a room)
|
||||
room_id TEXT,
|
||||
-- The account data type
|
||||
type TEXT NOT NULL,
|
||||
-- The account data content
|
||||
content TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY(localpart, room_id, type)
|
||||
content TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_account_datas_idx ON userapi_account_datas(localpart, server_name, room_id, type);
|
||||
`
|
||||
|
||||
const insertAccountDataSQL = `
|
||||
INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
|
||||
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
|
||||
INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = $5
|
||||
`
|
||||
|
||||
const selectAccountDataSQL = "" +
|
||||
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1"
|
||||
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectAccountDataByTypeSQL = "" +
|
||||
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
||||
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4"
|
||||
|
||||
type accountDataStatements struct {
|
||||
db *sql.DB
|
||||
|
|
@ -73,20 +75,23 @@ func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
|
|||
}
|
||||
|
||||
func (s *accountDataStatements) InsertAccountData(
|
||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
roomID, dataType string, content json.RawMessage,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
|
||||
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, serverName, roomID, dataType, content)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) SelectAccountData(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (
|
||||
/* global */ map[string]json.RawMessage,
|
||||
/* rooms */ map[string]map[string]json.RawMessage,
|
||||
error,
|
||||
) {
|
||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
|
||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
|
@ -117,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData(
|
|||
}
|
||||
|
||||
func (s *accountDataStatements) SelectAccountDataByType(
|
||||
ctx context.Context, localpart, roomID, dataType string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
roomID, dataType string,
|
||||
) (data json.RawMessage, err error) {
|
||||
var bytes []byte
|
||||
stmt := s.selectAccountDataByTypeStmt
|
||||
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
|
||||
if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -34,7 +34,8 @@ const accountsSchema = `
|
|||
-- Stores data about accounts.
|
||||
CREATE TABLE IF NOT EXISTS userapi_accounts (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
-- When this account was first created, as a unix timestamp (ms resolution).
|
||||
created_ts BIGINT NOT NULL,
|
||||
-- The password hash for this account. Can be NULL if this is a passwordless account.
|
||||
|
|
@ -48,25 +49,27 @@ CREATE TABLE IF NOT EXISTS userapi_accounts (
|
|||
-- TODO:
|
||||
-- upgraded_ts, devices, any email reset stuff?
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_accounts_idx ON userapi_accounts(localpart, server_name);
|
||||
`
|
||||
|
||||
const insertAccountSQL = "" +
|
||||
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
|
||||
"INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||
|
||||
const updatePasswordSQL = "" +
|
||||
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3"
|
||||
|
||||
const deactivateAccountSQL = "" +
|
||||
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1"
|
||||
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectAccountByLocalpartSQL = "" +
|
||||
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
|
||||
"SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectPasswordHashSQL = "" +
|
||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = 0"
|
||||
|
||||
const selectNewNumericLocalpartSQL = "" +
|
||||
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0"
|
||||
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0 AND server_name = $1"
|
||||
|
||||
type accountsStatements struct {
|
||||
db *sql.DB
|
||||
|
|
@ -119,16 +122,17 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
|||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||
// on success.
|
||||
func (s *accountsStatements) InsertAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
|
||||
ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
hash, appserviceID string, accountType api.AccountType,
|
||||
) (*api.Account, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
stmt := s.insertAccountStmt
|
||||
|
||||
var err error
|
||||
if accountType != api.AccountTypeAppService {
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType)
|
||||
} else {
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -136,42 +140,43 @@ func (s *accountsStatements) InsertAccount(
|
|||
|
||||
return &api.Account{
|
||||
Localpart: localpart,
|
||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||
ServerName: s.serverName,
|
||||
UserID: userutil.MakeUserID(localpart, serverName),
|
||||
ServerName: serverName,
|
||||
AppServiceID: appserviceID,
|
||||
AccountType: accountType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *accountsStatements) UpdatePassword(
|
||||
ctx context.Context, localpart, passwordHash string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
passwordHash string,
|
||||
) (err error) {
|
||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
|
||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) DeactivateAccount(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (err error) {
|
||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
|
||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) SelectPasswordHash(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (hash string, err error) {
|
||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
|
||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) SelectAccountByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*api.Account, error) {
|
||||
var appserviceIDPtr sql.NullString
|
||||
var acc api.Account
|
||||
|
||||
stmt := s.selectAccountByLocalpartStmt
|
||||
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
|
||||
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
log.WithError(err).Error("Unable to retrieve user from the db")
|
||||
|
|
@ -182,20 +187,18 @@ func (s *accountsStatements) SelectAccountByLocalpart(
|
|||
acc.AppServiceID = appserviceIDPtr.String
|
||||
}
|
||||
|
||||
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
acc.ServerName = s.serverName
|
||||
|
||||
acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName)
|
||||
return &acc, nil
|
||||
}
|
||||
|
||||
func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) (id int64, err error) {
|
||||
stmt := s.selectNewNumericLocalpartStmt
|
||||
if txn != nil {
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||
err = stmt.QueryRowContext(ctx, serverName).Scan(&id)
|
||||
if err == sql.ErrNoRows {
|
||||
return 1, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error {
|
|||
ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
|
||||
CREATE TABLE userapi_accounts (
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
server_name TEXT NOT NULL,
|
||||
created_ts BIGINT NOT NULL,
|
||||
password_hash TEXT,
|
||||
appservice_id TEXT,
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
|
|||
session_id INTEGER,
|
||||
device_id TEXT ,
|
||||
localpart TEXT ,
|
||||
server_name TEXT NOT NULL,
|
||||
created_ts BIGINT,
|
||||
display_name TEXT,
|
||||
last_seen_ts BIGINT,
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
|
|||
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
|
||||
CREATE TABLE userapi_accounts (
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
server_name TEXT NOT NULL,
|
||||
created_ts BIGINT NOT NULL,
|
||||
password_hash TEXT,
|
||||
appservice_id TEXT,
|
||||
|
|
|
|||
108
userapi/storage/sqlite3/deltas/2022110411000000_server_names.go
Normal file
108
userapi/storage/sqlite3/deltas/2022110411000000_server_names.go
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var serverNamesTables = []string{
|
||||
"userapi_accounts",
|
||||
"userapi_account_datas",
|
||||
"userapi_devices",
|
||||
"userapi_notifications",
|
||||
"userapi_openid_tokens",
|
||||
"userapi_profiles",
|
||||
"userapi_pushers",
|
||||
"userapi_threepids",
|
||||
}
|
||||
|
||||
// These tables have a PRIMARY KEY constraint which we need to drop so
|
||||
// that we can recreate a new unique index that contains the server name.
|
||||
var serverNamesDropPK = []string{
|
||||
"userapi_accounts",
|
||||
"userapi_account_datas",
|
||||
"userapi_profiles",
|
||||
}
|
||||
|
||||
// These indices are out of date so let's drop them. They will get recreated
|
||||
// automatically.
|
||||
var serverNamesDropIndex = []string{
|
||||
"userapi_pusher_localpart_idx",
|
||||
"userapi_pusher_app_id_pushkey_localpart_idx",
|
||||
}
|
||||
|
||||
// I know what you're thinking: you're wondering "why doesn't this use $1
|
||||
// and pass variadic parameters to ExecContext?" — the answer is because
|
||||
// PostgreSQL doesn't expect the table name to be specified as a substituted
|
||||
// argument in that way so it results in a syntax error in the query.
|
||||
|
||||
func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
|
||||
for _, table := range serverNamesTables {
|
||||
q := fmt.Sprintf(
|
||||
"SELECT COUNT(name) FROM sqlite_schema WHERE type='table' AND name=%s;",
|
||||
pq.QuoteIdentifier(table),
|
||||
)
|
||||
var c int
|
||||
if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 0 {
|
||||
continue
|
||||
}
|
||||
q = fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM pragma_table_info(%s) WHERE name='server_name'",
|
||||
pq.QuoteIdentifier(table),
|
||||
)
|
||||
if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 1 {
|
||||
logrus.Infof("Table %s already has column, skipping", table)
|
||||
continue
|
||||
}
|
||||
if c == 0 {
|
||||
q = fmt.Sprintf(
|
||||
"ALTER TABLE %s ADD COLUMN server_name TEXT NOT NULL DEFAULT '';",
|
||||
pq.QuoteIdentifier(table),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("add server name to %q error: %w", table, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, table := range serverNamesDropPK {
|
||||
q := fmt.Sprintf(
|
||||
"SELECT COUNT(name), sql FROM sqlite_schema WHERE type='table' AND name=%s;",
|
||||
pq.QuoteIdentifier(table),
|
||||
)
|
||||
var c int
|
||||
var sql string
|
||||
if err := tx.QueryRowContext(ctx, q).Scan(&c, &sql); err != nil || c == 0 {
|
||||
continue
|
||||
}
|
||||
q = fmt.Sprintf(`
|
||||
%s; -- create temporary table
|
||||
INSERT INTO %s SELECT * FROM %s; -- copy data
|
||||
DROP TABLE %s; -- drop original table
|
||||
ALTER TABLE %s RENAME TO %s; -- rename new table
|
||||
`,
|
||||
strings.Replace(sql, table, table+"_tmp", 1), // create temporary table
|
||||
table+"_tmp", table, // copy data
|
||||
table, // drop original table
|
||||
table+"_tmp", table, // rename new table
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("drop PK from %q error: %w", table, err)
|
||||
}
|
||||
}
|
||||
for _, index := range serverNamesDropIndex {
|
||||
q := fmt.Sprintf(
|
||||
"DROP INDEX IF EXISTS %s;",
|
||||
pq.QuoteIdentifier(index),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("drop index %q error: %w", index, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// I know what you're thinking: you're wondering "why doesn't this use $1
|
||||
// and pass variadic parameters to ExecContext?" — the answer is because
|
||||
// PostgreSQL doesn't expect the table name to be specified as a substituted
|
||||
// argument in that way so it results in a syntax error in the query.
|
||||
|
||||
func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
|
||||
for _, table := range serverNamesTables {
|
||||
q := fmt.Sprintf(
|
||||
"UPDATE %s SET server_name = %s WHERE server_name = '';",
|
||||
pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("write server names to %q error: %w", table, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -40,49 +40,50 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
|
|||
session_id INTEGER,
|
||||
device_id TEXT ,
|
||||
localpart TEXT ,
|
||||
server_name TEXT NOT NULL,
|
||||
created_ts BIGINT,
|
||||
display_name TEXT,
|
||||
last_seen_ts BIGINT,
|
||||
ip TEXT,
|
||||
user_agent TEXT,
|
||||
|
||||
UNIQUE (localpart, device_id)
|
||||
UNIQUE (localpart, server_name, device_id)
|
||||
);
|
||||
`
|
||||
|
||||
const insertDeviceSQL = "" +
|
||||
"INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
|
||||
"INSERT INTO userapi_devices (device_id, localpart, server_name, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"
|
||||
|
||||
const selectDevicesCountSQL = "" +
|
||||
"SELECT COUNT(access_token) FROM userapi_devices"
|
||||
|
||||
const selectDeviceByTokenSQL = "" +
|
||||
"SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
|
||||
"SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1"
|
||||
|
||||
const selectDeviceByIDSQL = "" +
|
||||
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
|
||||
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
|
||||
|
||||
const selectDevicesByLocalpartSQL = "" +
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
|
||||
|
||||
const updateDeviceNameSQL = "" +
|
||||
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4"
|
||||
|
||||
const deleteDeviceSQL = "" +
|
||||
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
|
||||
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3"
|
||||
|
||||
const deleteDevicesByLocalpartSQL = "" +
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3"
|
||||
|
||||
const deleteDevicesSQL = "" +
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id IN ($2)"
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id IN ($3)"
|
||||
|
||||
const selectDevicesByIDSQL = "" +
|
||||
"SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
|
||||
"SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
|
||||
|
||||
const updateDeviceLastSeen = "" +
|
||||
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
|
||||
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
|
||||
|
||||
type devicesStatements struct {
|
||||
db *sql.DB
|
||||
|
|
@ -135,8 +136,9 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
|||
// Returns an error if the user already has a device with the given device ID.
|
||||
// Returns the device on success.
|
||||
func (s *devicesStatements) InsertDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
||||
displayName *string, ipAddr, userAgent string,
|
||||
ctx context.Context, txn *sql.Tx, id string,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
accessToken string, displayName *string, ipAddr, userAgent string,
|
||||
) (*api.Device, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
var sessionID int64
|
||||
|
|
@ -146,12 +148,12 @@ func (s *devicesStatements) InsertDevice(
|
|||
return nil, err
|
||||
}
|
||||
sessionID++
|
||||
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
|
||||
if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &api.Device{
|
||||
ID: id,
|
||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||
UserID: userutil.MakeUserID(localpart, serverName),
|
||||
AccessToken: accessToken,
|
||||
SessionID: sessionID,
|
||||
LastSeenTS: createdTimeMS,
|
||||
|
|
@ -161,44 +163,52 @@ func (s *devicesStatements) InsertDevice(
|
|||
}
|
||||
|
||||
func (s *devicesStatements) DeleteDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart string,
|
||||
ctx context.Context, txn *sql.Tx, id string,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
||||
_, err := stmt.ExecContext(ctx, id, localpart)
|
||||
_, err := stmt.ExecContext(ctx, id, localpart, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) DeleteDevices(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
devices []string,
|
||||
) error {
|
||||
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
|
||||
orig := strings.Replace(deleteDevicesSQL, "($3)", sqlutil.QueryVariadicOffset(len(devices), 2), 1)
|
||||
prep, err := s.db.Prepare(orig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, prep)
|
||||
params := make([]interface{}, len(devices)+1)
|
||||
params := make([]interface{}, len(devices)+2)
|
||||
params[0] = localpart
|
||||
params[1] = serverName
|
||||
for i, v := range devices {
|
||||
params[i+1] = v
|
||||
params[i+2] = v
|
||||
}
|
||||
_, err = stmt.ExecContext(ctx, params...)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) DeleteDevicesByLocalpart(
|
||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
exceptDeviceID string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
||||
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
|
||||
_, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) UpdateDeviceName(
|
||||
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
deviceID string, displayName *string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
||||
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
|
||||
_, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -207,10 +217,11 @@ func (s *devicesStatements) SelectDeviceByToken(
|
|||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
var localpart string
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
stmt := s.selectDeviceByTokenStmt
|
||||
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
|
||||
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName)
|
||||
if err == nil {
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||
dev.AccessToken = accessToken
|
||||
}
|
||||
return &dev, err
|
||||
|
|
@ -219,16 +230,18 @@ func (s *devicesStatements) SelectDeviceByToken(
|
|||
// selectDeviceByID retrieves a device from the database with the given user
|
||||
// localpart and deviceID
|
||||
func (s *devicesStatements) SelectDeviceByID(
|
||||
ctx context.Context, localpart, deviceID string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
deviceID string,
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
var displayName, ip sql.NullString
|
||||
stmt := s.selectDeviceByIDStmt
|
||||
var lastseenTS sql.NullInt64
|
||||
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip)
|
||||
err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip)
|
||||
if err == nil {
|
||||
dev.ID = deviceID
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||
if displayName.Valid {
|
||||
dev.DisplayName = displayName.String
|
||||
}
|
||||
|
|
@ -243,10 +256,12 @@ func (s *devicesStatements) SelectDeviceByID(
|
|||
}
|
||||
|
||||
func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
exceptDeviceID string,
|
||||
) ([]api.Device, error) {
|
||||
devices := []api.Device{}
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID)
|
||||
|
||||
if err != nil {
|
||||
return devices, err
|
||||
|
|
@ -276,7 +291,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
|
|||
dev.UserAgent = useragent.String
|
||||
}
|
||||
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||
devices = append(devices, dev)
|
||||
}
|
||||
|
||||
|
|
@ -298,10 +313,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
|||
var devices []api.Device
|
||||
var dev api.Device
|
||||
var localpart string
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
var displayName sql.NullString
|
||||
var lastseents sql.NullInt64
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
|
||||
if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if displayName.Valid {
|
||||
|
|
@ -310,15 +326,15 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
|||
if lastseents.Valid {
|
||||
dev.LastSeenTS = lastseents.Int64
|
||||
}
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||
devices = append(devices, dev)
|
||||
}
|
||||
return devices, rows.Err()
|
||||
}
|
||||
|
||||
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error {
|
||||
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
|
||||
lastSeenTs := time.Now().UnixNano() / 1000000
|
||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
||||
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID)
|
||||
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID)
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ const notificationSchema = `
|
|||
CREATE TABLE IF NOT EXISTS userapi_notifications (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
stream_pos BIGINT NOT NULL,
|
||||
|
|
@ -52,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications (
|
|||
read BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id);
|
||||
`
|
||||
|
||||
const insertNotificationSQL = "" +
|
||||
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
||||
"INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
|
||||
|
||||
const deleteNotificationsUpToSQL = "" +
|
||||
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
|
||||
"DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4"
|
||||
|
||||
const updateNotificationReadSQL = "" +
|
||||
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
|
||||
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1"
|
||||
|
||||
const selectNotificationSQL = "" +
|
||||
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
|
||||
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
|
||||
") AND NOT read ORDER BY localpart, id LIMIT $4"
|
||||
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" +
|
||||
"(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" +
|
||||
") AND NOT read ORDER BY localpart, id LIMIT $5"
|
||||
|
||||
const selectNotificationCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
|
||||
"(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
|
||||
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" +
|
||||
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
|
||||
") AND NOT read"
|
||||
|
||||
const selectRoomNotificationCountsSQL = "" +
|
||||
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
|
||||
"WHERE localpart = $1 AND room_id = $2 AND NOT read"
|
||||
"WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read"
|
||||
|
||||
const cleanNotificationsSQL = "" +
|
||||
"DELETE FROM userapi_notifications WHERE" +
|
||||
|
|
@ -111,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
|
|||
}
|
||||
|
||||
// Insert inserts a notification into the database.
|
||||
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
|
||||
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error {
|
||||
roomID, tsMS := n.RoomID, n.TS
|
||||
nn := *n
|
||||
// Clears out fields that have their own columns to (1) shrink the
|
||||
|
|
@ -122,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
|
||||
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs))
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteUpTo deletes all previous notifications, up to and including the event.
|
||||
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
|
||||
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
|
@ -141,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
|
|||
}
|
||||
|
||||
// UpdateRead updates the "read" value for an event.
|
||||
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
|
||||
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
|
@ -154,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l
|
|||
return nrows > 0, nil
|
||||
}
|
||||
|
||||
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
|
||||
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit)
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
|
|
@ -197,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
|
|||
return notifs, maxID, rows.Err()
|
||||
}
|
||||
|
||||
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) {
|
||||
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count)
|
||||
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) {
|
||||
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) {
|
||||
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight)
|
||||
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) {
|
||||
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight)
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package sqlite3
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
|
|
@ -18,16 +19,17 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
|
|||
token TEXT NOT NULL PRIMARY KEY,
|
||||
-- The Matrix user ID for this account
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
-- When the token expires, as a unix timestamp (ms resolution).
|
||||
token_expires_at_ms BIGINT NOT NULL
|
||||
);
|
||||
`
|
||||
|
||||
const insertOpenIDTokenSQL = "" +
|
||||
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const selectOpenIDTokenSQL = "" +
|
||||
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
||||
"SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
||||
|
||||
type openIDTokenStatements struct {
|
||||
db *sql.DB
|
||||
|
|
@ -56,11 +58,11 @@ func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (
|
|||
func (s *openIDTokenStatements) InsertOpenIDToken(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
token, localpart string,
|
||||
token, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
expiresAtMS int64,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
||||
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
|
||||
_, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -71,10 +73,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
|
|||
token string,
|
||||
) (*api.OpenIDTokenAttributes, error) {
|
||||
var openIDTokenAttrs api.OpenIDTokenAttributes
|
||||
var localpart string
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
||||
&openIDTokenAttrs.UserID,
|
||||
&localpart, &serverName,
|
||||
&openIDTokenAttrs.ExpiresAtMS,
|
||||
)
|
||||
openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
log.WithError(err).Error("Unable to retrieve token from the db")
|
||||
|
|
|
|||
|
|
@ -23,36 +23,40 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const profilesSchema = `
|
||||
-- Stores data about accounts profiles.
|
||||
CREATE TABLE IF NOT EXISTS userapi_profiles (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
-- The display name for this account
|
||||
display_name TEXT,
|
||||
-- The URL of the avatar for this account
|
||||
avatar_url TEXT
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_profiles_idx ON userapi_profiles(localpart, server_name);
|
||||
`
|
||||
|
||||
const insertProfileSQL = "" +
|
||||
"INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const selectProfileByLocalpartSQL = "" +
|
||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
|
||||
"SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const setAvatarURLSQL = "" +
|
||||
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" +
|
||||
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2 AND server_name = $3" +
|
||||
" RETURNING display_name"
|
||||
|
||||
const setDisplayNameSQL = "" +
|
||||
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" +
|
||||
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2 AND server_name = $3" +
|
||||
" RETURNING avatar_url"
|
||||
|
||||
const selectProfilesBySearchSQL = "" +
|
||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||
"SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||
|
||||
type profilesStatements struct {
|
||||
db *sql.DB
|
||||
|
|
@ -83,18 +87,20 @@ func NewSQLiteProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables.P
|
|||
}
|
||||
|
||||
func (s *profilesStatements) InsertProfile(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
||||
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "")
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *profilesStatements) SelectProfileByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*authtypes.Profile, error) {
|
||||
var profile authtypes.Profile
|
||||
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
|
||||
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
|
||||
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan(
|
||||
&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -103,13 +109,16 @@ func (s *profilesStatements) SelectProfileByLocalpart(
|
|||
}
|
||||
|
||||
func (s *profilesStatements) SetAvatarURL(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
avatarURL string,
|
||||
) (*authtypes.Profile, bool, error) {
|
||||
profile := &authtypes.Profile{
|
||||
Localpart: localpart,
|
||||
AvatarURL: avatarURL,
|
||||
Localpart: localpart,
|
||||
ServerName: string(serverName),
|
||||
AvatarURL: avatarURL,
|
||||
}
|
||||
old, err := s.SelectProfileByLocalpart(ctx, localpart)
|
||||
old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return old, false, err
|
||||
}
|
||||
|
|
@ -117,18 +126,21 @@ func (s *profilesStatements) SetAvatarURL(
|
|||
return old, false, nil
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
||||
err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName)
|
||||
err = stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName)
|
||||
return profile, true, err
|
||||
}
|
||||
|
||||
func (s *profilesStatements) SetDisplayName(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
displayName string,
|
||||
) (*authtypes.Profile, bool, error) {
|
||||
profile := &authtypes.Profile{
|
||||
Localpart: localpart,
|
||||
ServerName: string(serverName),
|
||||
DisplayName: displayName,
|
||||
}
|
||||
old, err := s.SelectProfileByLocalpart(ctx, localpart)
|
||||
old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return old, false, err
|
||||
}
|
||||
|
|
@ -136,7 +148,7 @@ func (s *profilesStatements) SetDisplayName(
|
|||
return old, false, nil
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
||||
err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL)
|
||||
err = stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL)
|
||||
return profile, true, err
|
||||
}
|
||||
|
||||
|
|
@ -154,7 +166,7 @@ func (s *profilesStatements) SelectProfilesBySearch(
|
|||
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var profile authtypes.Profile
|
||||
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
||||
if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if profile.Localpart != s.serverNoticesLocalpart {
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
|
||||
|
|
@ -33,6 +34,7 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
|
|||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
-- The Matrix user ID localpart for this pusher
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
session_id BIGINT DEFAULT NULL,
|
||||
profile_tag TEXT,
|
||||
kind TEXT NOT NULL,
|
||||
|
|
@ -49,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
|
|||
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
|
||||
|
||||
-- For faster retrieving by localpart.
|
||||
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
|
||||
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name);
|
||||
|
||||
-- Pushkey must be unique for a given user and app.
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name);
|
||||
`
|
||||
|
||||
const insertPusherSQL = "" +
|
||||
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
|
||||
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
|
||||
"INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" +
|
||||
"ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12"
|
||||
|
||||
const selectPushersSQL = "" +
|
||||
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
|
||||
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const deletePusherSQL = "" +
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4"
|
||||
|
||||
const deletePushersByAppIdAndPushKeySQL = "" +
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
|
||||
|
|
@ -95,18 +97,19 @@ type pushersStatements struct {
|
|||
// Returns nil error success.
|
||||
func (s *pushersStatements) InsertPusher(
|
||||
ctx context.Context, txn *sql.Tx, session_id int64,
|
||||
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
|
||||
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
||||
logrus.Debugf("Created pusher %d", session_id)
|
||||
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *pushersStatements) SelectPushers(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) ([]api.Pusher, error) {
|
||||
pushers := []api.Pusher{}
|
||||
rows, err := s.selectPushersStmt.QueryContext(ctx, localpart)
|
||||
rows, err := s.selectPushersStmt.QueryContext(ctx, localpart, serverName)
|
||||
|
||||
if err != nil {
|
||||
return pushers, err
|
||||
|
|
@ -143,9 +146,10 @@ func (s *pushersStatements) SelectPushers(
|
|||
|
||||
// deletePusher removes a single pusher by pushkey and user localpart.
|
||||
func (s *pushersStatements) DeletePusher(
|
||||
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
|
||||
ctx context.Context, txn *sql.Tx, appid, pushkey,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
|
||||
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@
|
|||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
|
|
@ -41,18 +43,24 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
Up: deltas.UpRenameTables,
|
||||
Down: deltas.DownRenameTables,
|
||||
})
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "userapi: server names",
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
return deltas.UpServerNames(ctx, txn, serverName)
|
||||
},
|
||||
})
|
||||
if err = m.Up(base.Context()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountDataTable, err := NewSQLiteAccountDataTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
|
||||
}
|
||||
accountsTable, err := NewSQLiteAccountsTable(db, serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err)
|
||||
}
|
||||
accountDataTable, err := NewSQLiteAccountDataTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
|
||||
}
|
||||
devicesTable, err := NewSQLiteDevicesTable(db, serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err)
|
||||
|
|
@ -93,6 +101,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err)
|
||||
}
|
||||
|
||||
m = sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "userapi: server names populate",
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
return deltas.UpServerNamesPopulate(ctx, txn, serverName)
|
||||
},
|
||||
})
|
||||
if err = m.Up(base.Context()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &shared.Database{
|
||||
AccountDatas: accountDataTable,
|
||||
Accounts: accountsTable,
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
)
|
||||
|
|
@ -34,21 +35,22 @@ CREATE TABLE IF NOT EXISTS userapi_threepids (
|
|||
medium TEXT NOT NULL DEFAULT 'email',
|
||||
-- The localpart of the Matrix user ID associated to this 3PID
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY(threepid, medium)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart);
|
||||
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart, server_name);
|
||||
`
|
||||
|
||||
const selectLocalpartForThreePIDSQL = "" +
|
||||
"SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||
"SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||
|
||||
const selectThreePIDsForLocalpartSQL = "" +
|
||||
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1"
|
||||
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const insertThreePIDSQL = "" +
|
||||
"INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const deleteThreePIDSQL = "" +
|
||||
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||
|
|
@ -79,19 +81,20 @@ func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
|
|||
|
||||
func (s *threepidStatements) SelectLocalpartForThreePID(
|
||||
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
||||
) (localpart string, err error) {
|
||||
) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
||||
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
|
||||
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", nil
|
||||
return "", "", nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (threepids []authtypes.ThreePID, err error) {
|
||||
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
|
||||
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -113,10 +116,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
|||
}
|
||||
|
||||
func (s *threepidStatements) InsertThreePID(
|
||||
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
|
||||
ctx context.Context, txn *sql.Tx, threepid, medium,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
||||
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
|
||||
_, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -50,25 +50,25 @@ func Test_AccountData(t *testing.T) {
|
|||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser(t)
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
room := test.NewRoom(t, alice)
|
||||
events := room.Events()
|
||||
|
||||
contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID()))
|
||||
err = db.SaveAccountData(ctx, localpart, room.ID, "m.fully_read", contentRoom)
|
||||
err = db.SaveAccountData(ctx, localpart, domain, room.ID, "m.fully_read", contentRoom)
|
||||
assert.NoError(t, err, "unable to save account data")
|
||||
|
||||
contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID))
|
||||
err = db.SaveAccountData(ctx, localpart, "", "im.vector.setting.breadcrumbs", contentGlobal)
|
||||
err = db.SaveAccountData(ctx, localpart, domain, "", "im.vector.setting.breadcrumbs", contentGlobal)
|
||||
assert.NoError(t, err, "unable to save account data")
|
||||
|
||||
accountData, err := db.GetAccountDataByType(ctx, localpart, room.ID, "m.fully_read")
|
||||
accountData, err := db.GetAccountDataByType(ctx, localpart, domain, room.ID, "m.fully_read")
|
||||
assert.NoError(t, err, "unable to get account data by type")
|
||||
assert.Equal(t, contentRoom, accountData)
|
||||
|
||||
globalData, roomData, err := db.GetAccountData(ctx, localpart)
|
||||
globalData, roomData, err := db.GetAccountData(ctx, localpart, domain)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"])
|
||||
assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"])
|
||||
|
|
@ -81,78 +81,78 @@ func Test_Accounts(t *testing.T) {
|
|||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
|
||||
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin)
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
// verify the newly create account is the same as returned by CreateAccount
|
||||
var accGet *api.Account
|
||||
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing")
|
||||
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "testing")
|
||||
assert.NoError(t, err, "failed to get account by password")
|
||||
assert.Equal(t, accAlice, accGet)
|
||||
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart)
|
||||
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "failed to get account by localpart")
|
||||
assert.Equal(t, accAlice, accGet)
|
||||
|
||||
// check account availability
|
||||
available, err := db.CheckAccountAvailability(ctx, aliceLocalpart)
|
||||
available, err := db.CheckAccountAvailability(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "failed to checkout account availability")
|
||||
assert.Equal(t, false, available)
|
||||
|
||||
available, err = db.CheckAccountAvailability(ctx, "unusedname")
|
||||
available, err = db.CheckAccountAvailability(ctx, "unusedname", aliceDomain)
|
||||
assert.NoError(t, err, "failed to checkout account availability")
|
||||
assert.Equal(t, true, available)
|
||||
|
||||
// get guest account numeric aliceLocalpart
|
||||
first, err := db.GetNewNumericLocalpart(ctx)
|
||||
first, err := db.GetNewNumericLocalpart(ctx, aliceDomain)
|
||||
assert.NoError(t, err, "failed to get new numeric localpart")
|
||||
// Create a new account to verify the numeric localpart is updated
|
||||
_, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest)
|
||||
_, err = db.CreateAccount(ctx, "", aliceDomain, "testing", "", api.AccountTypeGuest)
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
second, err := db.GetNewNumericLocalpart(ctx)
|
||||
second, err := db.GetNewNumericLocalpart(ctx, aliceDomain)
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, second, first)
|
||||
|
||||
// update password for alice
|
||||
err = db.SetPassword(ctx, aliceLocalpart, "newPassword")
|
||||
err = db.SetPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
|
||||
assert.NoError(t, err, "failed to update password")
|
||||
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
|
||||
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
|
||||
assert.NoError(t, err, "failed to get account by new password")
|
||||
assert.Equal(t, accAlice, accGet)
|
||||
|
||||
// deactivate account
|
||||
err = db.DeactivateAccount(ctx, aliceLocalpart)
|
||||
err = db.DeactivateAccount(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "failed to deactivate account")
|
||||
// This should fail now, as the account is deactivated
|
||||
_, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
|
||||
_, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
|
||||
assert.Error(t, err, "expected an error, got none")
|
||||
|
||||
_, err = db.GetAccountByLocalpart(ctx, "unusename")
|
||||
_, err = db.GetAccountByLocalpart(ctx, "unusename", aliceDomain)
|
||||
assert.Error(t, err, "expected an error for non existent localpart")
|
||||
|
||||
// create an empty localpart; this should never happen, but is required to test getting a numeric localpart
|
||||
// if there's already a user without a localpart in the database
|
||||
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeUser)
|
||||
_, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeUser)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// test getting a numeric localpart, with an existing user without a localpart
|
||||
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest)
|
||||
_, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create a user with a high numeric localpart, out of range for the Postgres integer (2147483647) type
|
||||
_, err = db.CreateAccount(ctx, "2147483650", "", "", api.AccountTypeUser)
|
||||
_, err = db.CreateAccount(ctx, "2147483650", aliceDomain, "", "", api.AccountTypeUser)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Now try to create a new guest user
|
||||
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest)
|
||||
_, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Devices(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
deviceID := util.RandomString(8)
|
||||
accessToken := util.RandomString(16)
|
||||
|
|
@ -161,10 +161,10 @@ func Test_Devices(t *testing.T) {
|
|||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
deviceWithID, err := db.CreateDevice(ctx, localpart, &deviceID, accessToken, nil, "", "")
|
||||
deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "")
|
||||
assert.NoError(t, err, "unable to create deviceWithoutID")
|
||||
|
||||
gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID)
|
||||
gotDevice, err := db.GetDeviceByID(ctx, localpart, domain, deviceID)
|
||||
assert.NoError(t, err, "unable to get device by id")
|
||||
assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields
|
||||
|
||||
|
|
@ -174,14 +174,14 @@ func Test_Devices(t *testing.T) {
|
|||
|
||||
// create a device without existing device ID
|
||||
accessToken = util.RandomString(16)
|
||||
deviceWithoutID, err := db.CreateDevice(ctx, localpart, nil, accessToken, nil, "", "")
|
||||
deviceWithoutID, err := db.CreateDevice(ctx, localpart, domain, nil, accessToken, nil, "", "")
|
||||
assert.NoError(t, err, "unable to create deviceWithoutID")
|
||||
gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID)
|
||||
gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, domain, deviceWithoutID.ID)
|
||||
assert.NoError(t, err, "unable to get device by id")
|
||||
assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields
|
||||
|
||||
// Get devices
|
||||
devices, err := db.GetDevicesByLocalpart(ctx, localpart)
|
||||
devices, err := db.GetDevicesByLocalpart(ctx, localpart, domain)
|
||||
assert.NoError(t, err, "unable to get devices by localpart")
|
||||
assert.Equal(t, 2, len(devices))
|
||||
deviceIDs := make([]string, 0, len(devices))
|
||||
|
|
@ -195,15 +195,15 @@ func Test_Devices(t *testing.T) {
|
|||
|
||||
// Update device
|
||||
newName := "new display name"
|
||||
err = db.UpdateDevice(ctx, localpart, deviceWithID.ID, &newName)
|
||||
err = db.UpdateDevice(ctx, localpart, domain, deviceWithID.ID, &newName)
|
||||
assert.NoError(t, err, "unable to update device displayname")
|
||||
updatedAfterTimestamp := time.Now().Unix()
|
||||
err = db.UpdateDeviceLastSeen(ctx, localpart, deviceWithID.ID, "127.0.0.1", "Element Web")
|
||||
err = db.UpdateDeviceLastSeen(ctx, localpart, domain, deviceWithID.ID, "127.0.0.1", "Element Web")
|
||||
assert.NoError(t, err, "unable to update device last seen")
|
||||
|
||||
deviceWithID.DisplayName = newName
|
||||
deviceWithID.LastSeenIP = "127.0.0.1"
|
||||
gotDevice, err = db.GetDeviceByID(ctx, localpart, deviceWithID.ID)
|
||||
gotDevice, err = db.GetDeviceByID(ctx, localpart, domain, deviceWithID.ID)
|
||||
assert.NoError(t, err, "unable to get device by id")
|
||||
assert.Equal(t, 2, len(devices))
|
||||
assert.Equal(t, deviceWithID.DisplayName, gotDevice.DisplayName)
|
||||
|
|
@ -213,20 +213,20 @@ func Test_Devices(t *testing.T) {
|
|||
// create one more device and remove the devices step by step
|
||||
newDeviceID := util.RandomString(16)
|
||||
accessToken = util.RandomString(16)
|
||||
_, err = db.CreateDevice(ctx, localpart, &newDeviceID, accessToken, nil, "", "")
|
||||
_, err = db.CreateDevice(ctx, localpart, domain, &newDeviceID, accessToken, nil, "", "")
|
||||
assert.NoError(t, err, "unable to create new device")
|
||||
|
||||
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
|
||||
devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain)
|
||||
assert.NoError(t, err, "unable to get device by id")
|
||||
assert.Equal(t, 3, len(devices))
|
||||
|
||||
err = db.RemoveDevices(ctx, localpart, deviceIDs)
|
||||
err = db.RemoveDevices(ctx, localpart, domain, deviceIDs)
|
||||
assert.NoError(t, err, "unable to remove devices")
|
||||
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
|
||||
devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain)
|
||||
assert.NoError(t, err, "unable to get device by id")
|
||||
assert.Equal(t, 1, len(devices))
|
||||
|
||||
deleted, err := db.RemoveAllDevices(ctx, localpart, "")
|
||||
deleted, err := db.RemoveAllDevices(ctx, localpart, domain, "")
|
||||
assert.NoError(t, err, "unable to remove all devices")
|
||||
assert.Equal(t, 1, len(deleted))
|
||||
assert.Equal(t, newDeviceID, deleted[0].ID)
|
||||
|
|
@ -364,7 +364,7 @@ func Test_OpenID(t *testing.T) {
|
|||
|
||||
func Test_Profile(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
|
|
@ -372,30 +372,33 @@ func Test_Profile(t *testing.T) {
|
|||
defer close()
|
||||
|
||||
// create account, which also creates a profile
|
||||
_, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
|
||||
_, err = db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin)
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart)
|
||||
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "unable to get profile by localpart")
|
||||
wantProfile := &authtypes.Profile{Localpart: aliceLocalpart}
|
||||
wantProfile := &authtypes.Profile{
|
||||
Localpart: aliceLocalpart,
|
||||
ServerName: string(aliceDomain),
|
||||
}
|
||||
assert.Equal(t, wantProfile, gotProfile)
|
||||
|
||||
// set avatar & displayname
|
||||
wantProfile.DisplayName = "Alice"
|
||||
gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, "Alice")
|
||||
gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, aliceDomain, "Alice")
|
||||
assert.Equal(t, wantProfile, gotProfile)
|
||||
assert.NoError(t, err, "unable to set displayname")
|
||||
assert.True(t, changed)
|
||||
|
||||
wantProfile.AvatarURL = "mxc://aliceAvatar"
|
||||
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
|
||||
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar")
|
||||
assert.NoError(t, err, "unable to set avatar url")
|
||||
assert.Equal(t, wantProfile, gotProfile)
|
||||
assert.True(t, changed)
|
||||
|
||||
// Setting the same avatar again doesn't change anything
|
||||
wantProfile.AvatarURL = "mxc://aliceAvatar"
|
||||
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
|
||||
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar")
|
||||
assert.NoError(t, err, "unable to set avatar url")
|
||||
assert.Equal(t, wantProfile, gotProfile)
|
||||
assert.False(t, changed)
|
||||
|
|
@ -410,7 +413,7 @@ func Test_Profile(t *testing.T) {
|
|||
|
||||
func Test_Pusher(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
|
|
@ -432,11 +435,11 @@ func Test_Pusher(t *testing.T) {
|
|||
ProfileTag: util.RandomString(8),
|
||||
Language: util.RandomString(2),
|
||||
}
|
||||
err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart)
|
||||
err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "unable to upsert pusher")
|
||||
|
||||
// check it was actually persisted
|
||||
gotPushers, err = db.GetPushers(ctx, aliceLocalpart)
|
||||
gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "unable to get pushers")
|
||||
assert.Equal(t, i+1, len(gotPushers))
|
||||
assert.Equal(t, wantPusher, gotPushers[i])
|
||||
|
|
@ -444,16 +447,16 @@ func Test_Pusher(t *testing.T) {
|
|||
}
|
||||
|
||||
// remove single pusher
|
||||
err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart)
|
||||
err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "unable to remove pusher")
|
||||
gotPushers, err := db.GetPushers(ctx, aliceLocalpart)
|
||||
gotPushers, err := db.GetPushers(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "unable to get pushers")
|
||||
assert.Equal(t, 1, len(gotPushers))
|
||||
|
||||
// remove last pusher
|
||||
err = db.RemovePushers(ctx, appID, pushKeys[1])
|
||||
assert.NoError(t, err, "unable to remove pusher")
|
||||
gotPushers, err = db.GetPushers(ctx, aliceLocalpart)
|
||||
gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "unable to get pushers")
|
||||
assert.Equal(t, 0, len(gotPushers))
|
||||
})
|
||||
|
|
@ -461,7 +464,7 @@ func Test_Pusher(t *testing.T) {
|
|||
|
||||
func Test_ThreePID(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
|
|
@ -469,15 +472,16 @@ func Test_ThreePID(t *testing.T) {
|
|||
defer close()
|
||||
threePID := util.RandomString(8)
|
||||
medium := util.RandomString(8)
|
||||
err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, medium)
|
||||
err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, aliceDomain, medium)
|
||||
assert.NoError(t, err, "unable to save threepid association")
|
||||
|
||||
// get the stored threepid
|
||||
gotLocalpart, err := db.GetLocalpartForThreePID(ctx, threePID, medium)
|
||||
gotLocalpart, gotDomain, err := db.GetLocalpartForThreePID(ctx, threePID, medium)
|
||||
assert.NoError(t, err, "unable to get localpart for threepid")
|
||||
assert.Equal(t, aliceLocalpart, gotLocalpart)
|
||||
assert.Equal(t, aliceDomain, gotDomain)
|
||||
|
||||
threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart)
|
||||
threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "unable to get threepids for localpart")
|
||||
assert.Equal(t, 1, len(threepids))
|
||||
assert.Equal(t, authtypes.ThreePID{
|
||||
|
|
@ -490,7 +494,7 @@ func Test_ThreePID(t *testing.T) {
|
|||
assert.NoError(t, err, "unexpected error")
|
||||
|
||||
// verify it was deleted
|
||||
threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart)
|
||||
threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "unable to get threepids for localpart")
|
||||
assert.Equal(t, 0, len(threepids))
|
||||
})
|
||||
|
|
@ -498,7 +502,7 @@ func Test_ThreePID(t *testing.T) {
|
|||
|
||||
func Test_Notification(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
room := test.NewRoom(t, alice)
|
||||
room2 := test.NewRoom(t, alice)
|
||||
|
|
@ -526,34 +530,34 @@ func Test_Notification(t *testing.T) {
|
|||
RoomID: roomID,
|
||||
TS: gomatrixserverlib.AsTimestamp(ts),
|
||||
}
|
||||
err = db.InsertNotification(ctx, aliceLocalpart, eventID, uint64(i+1), nil, notification)
|
||||
err = db.InsertNotification(ctx, aliceLocalpart, aliceDomain, eventID, uint64(i+1), nil, notification)
|
||||
assert.NoError(t, err, "unable to insert notification")
|
||||
}
|
||||
|
||||
// get notifications
|
||||
count, err := db.GetNotificationCount(ctx, aliceLocalpart, tables.AllNotifications)
|
||||
count, err := db.GetNotificationCount(ctx, aliceLocalpart, aliceDomain, tables.AllNotifications)
|
||||
assert.NoError(t, err, "unable to get notification count")
|
||||
assert.Equal(t, int64(10), count)
|
||||
notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, 0, 15, tables.AllNotifications)
|
||||
notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, aliceDomain, 0, 15, tables.AllNotifications)
|
||||
assert.NoError(t, err, "unable to get notifications")
|
||||
assert.Equal(t, int64(10), count)
|
||||
assert.Equal(t, 10, len(notifs))
|
||||
// ... for a specific room
|
||||
total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
|
||||
total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID)
|
||||
assert.NoError(t, err, "unable to get notifications for room")
|
||||
assert.Equal(t, int64(4), total)
|
||||
|
||||
// mark notification as read
|
||||
affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, room2.ID, 7, true)
|
||||
affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, aliceDomain, room2.ID, 7, true)
|
||||
assert.NoError(t, err, "unable to set notifications read")
|
||||
assert.True(t, affected)
|
||||
|
||||
// this should delete 2 notifications
|
||||
affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, room2.ID, 8)
|
||||
affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, aliceDomain, room2.ID, 8)
|
||||
assert.NoError(t, err, "unable to set notifications read")
|
||||
assert.True(t, affected)
|
||||
|
||||
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
|
||||
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID)
|
||||
assert.NoError(t, err, "unable to get notifications for room")
|
||||
assert.Equal(t, int64(2), total)
|
||||
|
||||
|
|
@ -562,7 +566,7 @@ func Test_Notification(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
// this should now return 0 notifications
|
||||
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
|
||||
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID)
|
||||
assert.NoError(t, err, "unable to get notifications for room")
|
||||
assert.Equal(t, int64(0), total)
|
||||
})
|
||||
|
|
|
|||
|
|
@ -28,31 +28,31 @@ import (
|
|||
)
|
||||
|
||||
type AccountDataTable interface {
|
||||
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage) error
|
||||
SelectAccountData(ctx context.Context, localpart string) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)
|
||||
SelectAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
|
||||
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error
|
||||
SelectAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)
|
||||
SelectAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error)
|
||||
}
|
||||
|
||||
type AccountsTable interface {
|
||||
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error)
|
||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
||||
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error)
|
||||
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
||||
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error)
|
||||
InsertAccount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||
UpdatePassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, passwordHash string) (err error)
|
||||
DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error)
|
||||
SelectPasswordHash(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (hash string, err error)
|
||||
SelectAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error)
|
||||
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (id int64, err error)
|
||||
}
|
||||
|
||||
type DevicesTable interface {
|
||||
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
|
||||
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error
|
||||
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error
|
||||
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error
|
||||
UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string) error
|
||||
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
|
||||
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error
|
||||
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) error
|
||||
UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error
|
||||
SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error)
|
||||
SelectDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
|
||||
SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) ([]api.Device, error)
|
||||
SelectDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error)
|
||||
SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) ([]api.Device, error)
|
||||
SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
|
||||
UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error
|
||||
UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error
|
||||
}
|
||||
|
||||
type KeyBackupTable interface {
|
||||
|
|
@ -79,40 +79,40 @@ type LoginTokenTable interface {
|
|||
}
|
||||
|
||||
type OpenIDTable interface {
|
||||
InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, expiresAtMS int64) (err error)
|
||||
InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64) (err error)
|
||||
SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
||||
}
|
||||
|
||||
type ProfileTable interface {
|
||||
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error
|
||||
SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
||||
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
|
||||
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, bool, error)
|
||||
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||
SelectProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error)
|
||||
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error)
|
||||
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error)
|
||||
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
||||
}
|
||||
|
||||
type ThreePIDTable interface {
|
||||
SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, err error)
|
||||
SelectThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
|
||||
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error)
|
||||
SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error)
|
||||
SelectThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error)
|
||||
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, serverName gomatrixserverlib.ServerName) (err error)
|
||||
DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error)
|
||||
}
|
||||
|
||||
type PusherTable interface {
|
||||
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
|
||||
SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error)
|
||||
DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error
|
||||
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||
SelectPushers(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error)
|
||||
DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||
DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error
|
||||
}
|
||||
|
||||
type NotificationTable interface {
|
||||
Clean(ctx context.Context, txn *sql.Tx) error
|
||||
Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error
|
||||
DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error)
|
||||
UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error)
|
||||
Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
|
||||
SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error)
|
||||
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error)
|
||||
Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error
|
||||
DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error)
|
||||
UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error)
|
||||
Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
|
||||
SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter NotificationFilter) (int64, error)
|
||||
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error)
|
||||
}
|
||||
|
||||
type StatsTable interface {
|
||||
|
|
|
|||
|
|
@ -79,6 +79,7 @@ func mustMakeAccountAndDevice(
|
|||
accDB tables.AccountsTable,
|
||||
devDB tables.DevicesTable,
|
||||
localpart string,
|
||||
serverName gomatrixserverlib.ServerName, // nolint:unparam
|
||||
accType api.AccountType,
|
||||
userAgent string,
|
||||
) {
|
||||
|
|
@ -89,11 +90,11 @@ func mustMakeAccountAndDevice(
|
|||
appServiceID = util.RandomString(16)
|
||||
}
|
||||
|
||||
_, err := accDB.InsertAccount(ctx, nil, localpart, "", appServiceID, accType)
|
||||
_, err := accDB.InsertAccount(ctx, nil, localpart, serverName, "", appServiceID, accType)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create account: %v", err)
|
||||
}
|
||||
_, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, util.RandomString(16), nil, "", userAgent)
|
||||
_, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, serverName, util.RandomString(16), nil, "", userAgent)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create device: %v", err)
|
||||
}
|
||||
|
|
@ -150,12 +151,12 @@ func Test_UserStatistics(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("Want Users", func(t *testing.T) {
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", api.AccountTypeUser, "Element Android")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", api.AccountTypeUser, "Element iOS")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", api.AccountTypeUser, "Element web")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", api.AccountTypeGuest, "Element Electron")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", api.AccountTypeAdmin, "gecko")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", api.AccountTypeAppService, "gecko")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", "localhost", api.AccountTypeUser, "Element Android")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", "localhost", api.AccountTypeUser, "Element iOS")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", "localhost", api.AccountTypeUser, "Element web")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", "localhost", api.AccountTypeGuest, "Element Electron")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", "localhost", api.AccountTypeAdmin, "gecko")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", "localhost", api.AccountTypeAppService, "gecko")
|
||||
gotStats, _, err := statsDB.UserStatistics(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
|
|
|
|||
|
|
@ -80,14 +80,14 @@ func TestQueryProfile(t *testing.T) {
|
|||
// only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added
|
||||
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite)
|
||||
defer close()
|
||||
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", api.AccountTypeUser)
|
||||
_, err := accountDB.CreateAccount(context.TODO(), "alice", serverName, "foobar", "", api.AccountTypeUser)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make account: %s", err)
|
||||
}
|
||||
if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil {
|
||||
if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", serverName, aliceAvatarURL); err != nil {
|
||||
t.Fatalf("failed to set avatar url: %s", err)
|
||||
}
|
||||
if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil {
|
||||
if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", serverName, aliceDisplayName); err != nil {
|
||||
t.Fatalf("failed to set display name: %s", err)
|
||||
}
|
||||
|
||||
|
|
@ -164,7 +164,7 @@ func TestPasswordlessLoginFails(t *testing.T) {
|
|||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
||||
defer close()
|
||||
_, err := accountDB.CreateAccount(ctx, "auser", "", "", api.AccountTypeAppService)
|
||||
_, err := accountDB.CreateAccount(ctx, "auser", serverName, "", "", api.AccountTypeAppService)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make account: %s", err)
|
||||
}
|
||||
|
|
@ -190,7 +190,7 @@ func TestLoginToken(t *testing.T) {
|
|||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
||||
defer close()
|
||||
_, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", api.AccountTypeUser)
|
||||
_, err := accountDB.CreateAccount(ctx, "auser", serverName, "apassword", "", api.AccountTypeUser)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make account: %s", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,10 +2,12 @@ package util
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
|
@ -17,10 +19,10 @@ type PusherDevice struct {
|
|||
}
|
||||
|
||||
// GetPushDevices pushes to the configured devices of a local user.
|
||||
func GetPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
|
||||
pushers, err := db.GetPushers(ctx, localpart)
|
||||
func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
|
||||
pushers, err := db.GetPushers(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("db.GetPushers: %w", err)
|
||||
}
|
||||
|
||||
devices := make([]*PusherDevice, 0, len(pushers))
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
|
@ -16,8 +17,8 @@ import (
|
|||
// a single goroutine is started when talking to the Push
|
||||
// gateways. There is no way to know when the background goroutine has
|
||||
// finished.
|
||||
func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, db storage.Database) error {
|
||||
pusherDevices, err := GetPushDevices(ctx, localpart, nil, db)
|
||||
func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.Database) error {
|
||||
pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -26,7 +27,7 @@ func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, loc
|
|||
return nil
|
||||
}
|
||||
|
||||
userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, tables.AllNotifications)
|
||||
userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, serverName, tables.AllNotifications)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue