diff --git a/go.mod b/go.mod index 22b01be3c..360ddcfc1 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/docker/go-connections v0.4.0 github.com/getsentry/sentry-go v0.11.0 github.com/gologme/log v1.2.0 + github.com/google/go-cmp v0.5.5 github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.4.2 github.com/h2non/filetype v1.1.1 // indirect diff --git a/go.sum b/go.sum index a7da02c5a..ce1a0a30a 100644 --- a/go.sum +++ b/go.sum @@ -494,6 +494,7 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= diff --git a/internal/pushgateway/client.go b/internal/pushgateway/client.go new file mode 100644 index 000000000..9b7e77b35 --- /dev/null +++ b/internal/pushgateway/client.go @@ -0,0 +1,64 @@ +package pushgateway + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/opentracing/opentracing-go" +) + +type httpClient struct { + hc *http.Client +} + +// NewHTTPClient creates a new Push Gateway client. +func NewHTTPClient(disableTLSValidation bool) Client { + hc := &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + DisableKeepAlives: true, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: disableTLSValidation, + }, + }, + } + return &httpClient{hc: hc} +} + +func (h *httpClient) Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "Notify") + defer span.Finish() + + body, err := json.Marshal(req) + if err != nil { + return err + } + hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return err + } + hreq.Header.Set("Content-Type", "application/json") + + hresp, err := h.hc.Do(hreq) + if err != nil { + return err + } + defer hresp.Body.Close() + + if hresp.StatusCode == http.StatusOK { + return json.NewDecoder(hresp.Body).Decode(resp) + } + + var errorBody struct { + Message string `json:"message"` + } + if err := json.NewDecoder(hresp.Body).Decode(&errorBody); err == nil { + return fmt.Errorf("push gateway: %d from %s: %s", hresp.StatusCode, url, errorBody.Message) + } + return fmt.Errorf("push gateway: %d from %s", hresp.StatusCode, url) +} diff --git a/internal/pushgateway/pushgateway.go b/internal/pushgateway/pushgateway.go new file mode 100644 index 000000000..1960acac5 --- /dev/null +++ b/internal/pushgateway/pushgateway.go @@ -0,0 +1,62 @@ +package pushgateway + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/gomatrixserverlib" +) + +// A Client is how interactions iwth a Push Gateway is done. +type Client interface { + // Notify sends a notification to the gateway at the given URL. + Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error +} + +type NotifyRequest struct { + Notification Notification `json:"notification"` // Required +} + +type NotifyResponse struct { + // Rejected is the list of device push keys that were rejected + // during the push. The caller should remove the push keys so they + // are not used again. + Rejected []string `json:"rejected"` // Required +} + +type Notification struct { + Content json.RawMessage `json:"content,omitempty"` + Counts *Counts `json:"counts,omitempty"` + Devices []*Device `json:"devices"` // Required + EventID string `json:"event_id,omitempty"` + ID string `json:"id,omitempty"` // Deprecated name for EventID. + Membership string `json:"membership,omitempty"` // UNSPEC: required for Sytest. + Prio Prio `json:"prio,omitempty"` + RoomAlias string `json:"room_alias,omitempty"` + RoomID string `json:"room_id,omitempty"` + RoomName string `json:"room_name,omitempty"` + Sender string `json:"sender,omitempty"` + SenderDisplayName string `json:"sender_display_name,omitempty"` + Type string `json:"type,omitempty"` + UserIsTarget bool `json:"user_is_target,omitempty"` +} + +type Counts struct { + MissedCalls int `json:"missed_calls,omitempty"` + Unread int `json:"unread"` // TODO: UNSPEC: the spec says zero must be omitted, but Sytest 61push/01message-pushed.pl requires it. +} + +type Device struct { + AppID string `json:"app_id"` // Required + Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys. + PushKey string `json:"pushkey"` // Required + PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"` + Tweaks map[string]interface{} `json:"tweaks,omitempty"` +} + +type Prio string + +const ( + HighPrio Prio = "high" + LowPrio Prio = "low" +)