Update gomatrixserverlib version (#476)

Signed-off-by: Andrew Morgan <andrewm@matrix.org>
This commit is contained in:
Andrew Morgan 2018-06-01 17:42:55 +01:00 committed by GitHub
parent 63dc2141ba
commit 241b1b5ace
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 662 additions and 121 deletions

View file

@ -52,15 +52,15 @@ func (d Database) FetcherName() string {
// FetchKeys implements gomatrixserverlib.KeyDatabase // FetchKeys implements gomatrixserverlib.KeyDatabase
func (d *Database) FetchKeys( func (d *Database) FetchKeys(
ctx context.Context, ctx context.Context,
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.PublicKeyLookupResult, error) { ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
return d.statements.bulkSelectServerKeys(ctx, requests) return d.statements.bulkSelectServerKeys(ctx, requests)
} }
// StoreKeys implements gomatrixserverlib.KeyDatabase // StoreKeys implements gomatrixserverlib.KeyDatabase
func (d *Database) StoreKeys( func (d *Database) StoreKeys(
ctx context.Context, ctx context.Context,
keyMap map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.PublicKeyLookupResult, keyMap map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
) error { ) error {
// TODO: Inserting all the keys within a single transaction may // TODO: Inserting all the keys within a single transaction may
// be more efficient since the transaction overhead can be quite // be more efficient since the transaction overhead can be quite

View file

@ -79,8 +79,8 @@ func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
func (s *serverKeyStatements) bulkSelectServerKeys( func (s *serverKeyStatements) bulkSelectServerKeys(
ctx context.Context, ctx context.Context,
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.PublicKeyLookupResult, error) { ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
var nameAndKeyIDs []string var nameAndKeyIDs []string
for request := range requests { for request := range requests {
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request)) nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
@ -91,7 +91,7 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
return nil, err return nil, err
} }
defer rows.Close() // nolint: errcheck defer rows.Close() // nolint: errcheck
results := map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.PublicKeyLookupResult{} results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
for rows.Next() { for rows.Next() {
var serverName string var serverName string
var keyID string var keyID string
@ -101,7 +101,7 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil { if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
return nil, err return nil, err
} }
r := gomatrixserverlib.PublicKeyRequest{ r := gomatrixserverlib.PublicKeyLookupRequest{
ServerName: gomatrixserverlib.ServerName(serverName), ServerName: gomatrixserverlib.ServerName(serverName),
KeyID: gomatrixserverlib.KeyID(keyID), KeyID: gomatrixserverlib.KeyID(keyID),
} }
@ -121,7 +121,7 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
func (s *serverKeyStatements) upsertServerKeys( func (s *serverKeyStatements) upsertServerKeys(
ctx context.Context, ctx context.Context,
request gomatrixserverlib.PublicKeyRequest, request gomatrixserverlib.PublicKeyLookupRequest,
key gomatrixserverlib.PublicKeyLookupResult, key gomatrixserverlib.PublicKeyLookupResult,
) error { ) error {
_, err := s.upsertServerKeysStmt.ExecContext( _, err := s.upsertServerKeysStmt.ExecContext(
@ -136,6 +136,6 @@ func (s *serverKeyStatements) upsertServerKeys(
return err return err
} }
func nameAndKeyID(request gomatrixserverlib.PublicKeyRequest) string { func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string {
return string(request.ServerName) + "\x1F" + string(request.KeyID) return string(request.ServerName) + "\x1F" + string(request.KeyID)
} }

2
vendor/manifest vendored
View file

@ -135,7 +135,7 @@
{ {
"importpath": "github.com/matrix-org/gomatrixserverlib", "importpath": "github.com/matrix-org/gomatrixserverlib",
"repository": "https://github.com/matrix-org/gomatrixserverlib", "repository": "https://github.com/matrix-org/gomatrixserverlib",
"revision": "afa71391f946312c40639a419045e06b8ff2309a", "revision": "38a4f0f648bf357adc4bdb601cdc0535cee14e21",
"branch": "master" "branch": "master"
}, },
{ {

View file

@ -0,0 +1,35 @@
/* Copyright 2018 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gomatrixserverlib
// ApplicationServiceEvent is an event format that is sent off to an
// application service as part of a transaction.
type ApplicationServiceEvent struct {
Age int64 `json:"age,omitempty"`
Content RawJSON `json:"content,omitempty"`
EventID string `json:"event_id,omitempty"`
OriginServerTimestamp int64 `json:"origin_server_ts,omitempty"`
RoomID string `json:"room_id,omitempty"`
Sender string `json:"sender,omitempty"`
Type string `json:"type,omitempty"`
UserID string `json:"user_id,omitempty"`
}
// ApplicationServiceTransaction is the transaction that is sent off to an
// application service.
type ApplicationServiceTransaction struct {
Events []ApplicationServiceEvent `json:"events"`
}

View file

@ -30,6 +30,7 @@ import (
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
// Default HTTPS request timeout // Default HTTPS request timeout
@ -207,7 +208,7 @@ func (fc *Client) GetServerKeys(
// copy of the keys. // copy of the keys.
// Returns the keys returned by the server, or an error if there was a problem talking to the server. // Returns the keys returned by the server, or an error if there was a problem talking to the server.
func (fc *Client) LookupServerKeys( func (fc *Client) LookupServerKeys(
ctx context.Context, matrixServer ServerName, keyRequests map[PublicKeyRequest]Timestamp, ctx context.Context, matrixServer ServerName, keyRequests map[PublicKeyLookupRequest]Timestamp,
) ([]ServerKeys, error) { ) ([]ServerKeys, error) {
url := url.URL{ url := url.URL{
Scheme: "matrix", Scheme: "matrix",
@ -332,18 +333,26 @@ func (fc *Client) DoRequestAndParseResponse(
// //
func (fc *Client) DoHTTPRequest(ctx context.Context, req *http.Request) (*http.Response, error) { func (fc *Client) DoHTTPRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
reqID := util.RandomString(12) reqID := util.RandomString(12)
logger := util.GetLogger(ctx).WithField("server", req.URL.Host).WithField("out.req.ID", reqID) logger := util.GetLogger(ctx).WithFields(logrus.Fields{
"out.req.ID": reqID,
"out.req.method": req.Method,
"out.req.uri": req.URL,
})
logger.Info("Outgoing request")
newCtx := util.ContextWithLogger(ctx, logger) newCtx := util.ContextWithLogger(ctx, logger)
logger.Infof("Outgoing request %s %s", req.Method, req.URL) start := time.Now()
resp, err := fc.client.Do(req.WithContext(newCtx)) resp, err := fc.client.Do(req.WithContext(newCtx))
if err != nil { if err != nil {
logger.Infof("Outgoing request %s %s failed with %v", req.Method, req.URL, err) logger.WithField("error", err).Warn("Outgoing request failed")
return nil, err return nil, err
} }
// we haven't yet read the body, so this is slightly premature, but it's the easiest place. // we haven't yet read the body, so this is slightly premature, but it's the easiest place.
logger.Infof("Response %d from %s %s", resp.StatusCode, req.Method, req.URL) logger.WithFields(logrus.Fields{
"out.req.code": resp.StatusCode,
"out.req.duration_ms": int(time.Since(start) / time.Millisecond),
}).Info("Outgoing request returned")
return resp, nil return resp, nil
} }

View file

@ -28,7 +28,7 @@ const (
// ClientEvent is an event which is fit for consumption by clients, in accordance with the specification. // ClientEvent is an event which is fit for consumption by clients, in accordance with the specification.
type ClientEvent struct { type ClientEvent struct {
Content rawJSON `json:"content"` Content RawJSON `json:"content"`
EventID string `json:"event_id"` EventID string `json:"event_id"`
OriginServerTS Timestamp `json:"origin_server_ts"` OriginServerTS Timestamp `json:"origin_server_ts"`
// RoomID is omitted on /sync responses // RoomID is omitted on /sync responses
@ -36,7 +36,7 @@ type ClientEvent struct {
Sender string `json:"sender"` Sender string `json:"sender"`
StateKey *string `json:"state_key,omitempty"` StateKey *string `json:"state_key,omitempty"`
Type string `json:"type"` Type string `json:"type"`
Unsigned rawJSON `json:"unsigned,omitempty"` Unsigned RawJSON `json:"unsigned,omitempty"`
} }
// ToClientEvents converts server events to client events. // ToClientEvents converts server events to client events.
@ -51,11 +51,11 @@ func ToClientEvents(serverEvs []Event, format EventFormat) []ClientEvent {
// ToClientEvent converts a single server event to a client event. // ToClientEvent converts a single server event to a client event.
func ToClientEvent(se Event, format EventFormat) ClientEvent { func ToClientEvent(se Event, format EventFormat) ClientEvent {
ce := ClientEvent{ ce := ClientEvent{
Content: rawJSON(se.Content()), Content: RawJSON(se.Content()),
Sender: se.Sender(), Sender: se.Sender(),
Type: se.Type(), Type: se.Type(),
StateKey: se.StateKey(), StateKey: se.StateKey(),
Unsigned: rawJSON(se.Unsigned()), Unsigned: RawJSON(se.Unsigned()),
OriginServerTS: se.OriginServerTS(), OriginServerTS: se.OriginServerTS(),
EventID: se.EventID(), EventID: se.EventID(),
} }

View file

@ -69,9 +69,9 @@ type EventBuilder struct {
// The create event has a depth of 1. // The create event has a depth of 1.
Depth int64 `json:"depth"` Depth int64 `json:"depth"`
// The JSON object for "content" key of the event. // The JSON object for "content" key of the event.
Content rawJSON `json:"content"` Content RawJSON `json:"content"`
// The JSON object for the "unsigned" key // The JSON object for the "unsigned" key
Unsigned rawJSON `json:"unsigned,omitempty"` Unsigned RawJSON `json:"unsigned,omitempty"`
} }
// SetContent sets the JSON content key of the event. // SetContent sets the JSON content key of the event.
@ -102,12 +102,12 @@ type eventFields struct {
Sender string `json:"sender"` Sender string `json:"sender"`
Type string `json:"type"` Type string `json:"type"`
StateKey *string `json:"state_key"` StateKey *string `json:"state_key"`
Content rawJSON `json:"content"` Content RawJSON `json:"content"`
PrevEvents []EventReference `json:"prev_events"` PrevEvents []EventReference `json:"prev_events"`
AuthEvents []EventReference `json:"auth_events"` AuthEvents []EventReference `json:"auth_events"`
Redacts string `json:"redacts"` Redacts string `json:"redacts"`
Depth int64 `json:"depth"` Depth int64 `json:"depth"`
Unsigned rawJSON `json:"unsigned"` Unsigned RawJSON `json:"unsigned"`
OriginServerTS Timestamp `json:"origin_server_ts"` OriginServerTS Timestamp `json:"origin_server_ts"`
Origin ServerName `json:"origin"` Origin ServerName `json:"origin"`
} }
@ -284,7 +284,7 @@ func (e Event) Redact() Event {
// SetUnsigned sets the unsigned key of the event. // SetUnsigned sets the unsigned key of the event.
// Returns a copy of the event with the "unsigned" key set. // Returns a copy of the event with the "unsigned" key set.
func (e Event) SetUnsigned(unsigned interface{}) (Event, error) { func (e Event) SetUnsigned(unsigned interface{}) (Event, error) {
var eventAsMap map[string]rawJSON var eventAsMap map[string]RawJSON
var err error var err error
if err = json.Unmarshal(e.eventJSON, &eventAsMap); err != nil { if err = json.Unmarshal(e.eventJSON, &eventAsMap); err != nil {
return Event{}, err return Event{}, err
@ -326,7 +326,7 @@ func (e *Event) SetUnsignedField(path string, value interface{}) error {
eventJSON = CanonicalJSONAssumeValid(eventJSON) eventJSON = CanonicalJSONAssumeValid(eventJSON)
res := gjson.GetBytes(eventJSON, "unsigned") res := gjson.GetBytes(eventJSON, "unsigned")
unsigned := rawJSONFromResult(res, eventJSON) unsigned := RawJSONFromResult(res, eventJSON)
e.eventJSON = eventJSON e.eventJSON = eventJSON
e.fields.Unsigned = unsigned e.fields.Unsigned = unsigned
@ -617,7 +617,7 @@ func (e Event) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements json.Unmarshaller // UnmarshalJSON implements json.Unmarshaller
func (er *EventReference) UnmarshalJSON(data []byte) error { func (er *EventReference) UnmarshalJSON(data []byte) error {
var tuple []rawJSON var tuple []RawJSON
if err := json.Unmarshal(data, &tuple); err != nil { if err := json.Unmarshal(data, &tuple); err != nil {
return err return err
} }

View file

@ -52,7 +52,7 @@ func stateNeededEquals(a, b StateNeeded) bool {
type testEventList []Event type testEventList []Event
func (tel *testEventList) UnmarshalJSON(data []byte) error { func (tel *testEventList) UnmarshalJSON(data []byte) error {
var eventJSONs []rawJSON var eventJSONs []RawJSON
var events []Event var events []Event
if err := json.Unmarshal([]byte(data), &eventJSONs); err != nil { if err := json.Unmarshal([]byte(data), &eventJSONs); err != nil {
return err return err
@ -997,7 +997,7 @@ func TestRedactAllowed(t *testing.T) {
} }
func TestAuthEvents(t *testing.T) { func TestAuthEvents(t *testing.T) {
power, err := NewEventFromTrustedJSON(rawJSON(`{ power, err := NewEventFromTrustedJSON(RawJSON(`{
"type": "m.room.power_levels", "type": "m.room.power_levels",
"state_key": "", "state_key": "",
"sender": "@u1:a", "sender": "@u1:a",
@ -1018,7 +1018,7 @@ func TestAuthEvents(t *testing.T) {
if e, err = a.PowerLevels(); err != nil || e != &power { if e, err = a.PowerLevels(); err != nil || e != &power {
t.Errorf("TestAuthEvents: failed to get same power_levels event") t.Errorf("TestAuthEvents: failed to get same power_levels event")
} }
create, err := NewEventFromTrustedJSON(rawJSON(`{ create, err := NewEventFromTrustedJSON(RawJSON(`{
"type": "m.room.create", "type": "m.room.create",
"state_key": "", "state_key": "",
"sender": "@u1:a", "sender": "@u1:a",

View file

@ -31,7 +31,7 @@ import (
// This hash is used to detect whether the unredacted content of the event is valid. // This hash is used to detect whether the unredacted content of the event is valid.
// Returns the event JSON with a "hashes" key added to it. // Returns the event JSON with a "hashes" key added to it.
func addContentHashesToEvent(eventJSON []byte) ([]byte, error) { func addContentHashesToEvent(eventJSON []byte) ([]byte, error) {
var event map[string]rawJSON var event map[string]RawJSON
if err := json.Unmarshal(eventJSON, &event); err != nil { if err := json.Unmarshal(eventJSON, &event); err != nil {
return nil, err return nil, err
@ -64,7 +64,7 @@ func addContentHashesToEvent(eventJSON []byte) ([]byte, error) {
if len(unsignedJSON) > 0 { if len(unsignedJSON) > 0 {
event["unsigned"] = unsignedJSON event["unsigned"] = unsignedJSON
} }
event["hashes"] = rawJSON(hashesJSON) event["hashes"] = RawJSON(hashesJSON)
return json.Marshal(event) return json.Marshal(event)
} }
@ -105,7 +105,7 @@ func referenceOfEvent(eventJSON []byte) (EventReference, error) {
return EventReference{}, err return EventReference{}, err
} }
var event map[string]rawJSON var event map[string]RawJSON
if err = json.Unmarshal(redactedJSON, &event); err != nil { if err = json.Unmarshal(redactedJSON, &event); err != nil {
return EventReference{}, err return EventReference{}, err
} }
@ -150,14 +150,14 @@ func signEvent(signingName string, keyID KeyID, privateKey ed25519.PrivateKey, e
} }
var signedEvent struct { var signedEvent struct {
Signatures rawJSON `json:"signatures"` Signatures RawJSON `json:"signatures"`
} }
if err := json.Unmarshal(signedJSON, &signedEvent); err != nil { if err := json.Unmarshal(signedJSON, &signedEvent); err != nil {
return nil, err return nil, err
} }
// Unmarshal the event JSON so that we can replace the signatures key. // Unmarshal the event JSON so that we can replace the signatures key.
var event map[string]rawJSON var event map[string]RawJSON
if err := json.Unmarshal(eventJSON, &event); err != nil { if err := json.Unmarshal(eventJSON, &event); err != nil {
return nil, err return nil, err
} }

View file

@ -195,7 +195,7 @@ func (r RespSendJoin) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements json.Unmarshaller // UnmarshalJSON implements json.Unmarshaller
func (r *RespSendJoin) UnmarshalJSON(data []byte) error { func (r *RespSendJoin) UnmarshalJSON(data []byte) error {
var tuple []rawJSON var tuple []RawJSON
if err := json.Unmarshal(data, &tuple); err != nil { if err := json.Unmarshal(data, &tuple); err != nil {
return err return err
} }
@ -306,7 +306,7 @@ func (r RespInvite) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements json.Unmarshaller // UnmarshalJSON implements json.Unmarshaller
func (r *RespInvite) UnmarshalJSON(data []byte) error { func (r *RespInvite) UnmarshalJSON(data []byte) error {
var tuple []rawJSON var tuple []RawJSON
if err := json.Unmarshal(data, &tuple); err != nil { if err := json.Unmarshal(data, &tuple); err != nil {
return err return err
} }

View file

@ -0,0 +1,90 @@
// Copyright 2017 Jan Christian Grünhage
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package gomatrixserverlib
import "errors"
// Filter is used by clients to specify how the server should filter responses to e.g. sync requests
// Specified by: https://matrix.org/docs/spec/client_server/r0.2.0.html#filtering
type Filter struct {
AccountData FilterPart `json:"account_data,omitempty"`
EventFields []string `json:"event_fields,omitempty"`
EventFormat string `json:"event_format,omitempty"`
Presence FilterPart `json:"presence,omitempty"`
Room RoomFilter `json:"room,omitempty"`
}
// RoomFilter is used to define filtering rules for room events
type RoomFilter struct {
AccountData FilterPart `json:"account_data,omitempty"`
Ephemeral FilterPart `json:"ephemeral,omitempty"`
IncludeLeave bool `json:"include_leave,omitempty"`
NotRooms []string `json:"not_rooms,omitempty"`
Rooms []string `json:"rooms,omitempty"`
State FilterPart `json:"state,omitempty"`
Timeline FilterPart `json:"timeline,omitempty"`
}
// FilterPart is used to define filtering rules for specific categories of events
type FilterPart struct {
NotRooms []string `json:"not_rooms,omitempty"`
Rooms []string `json:"rooms,omitempty"`
Limit int `json:"limit,omitempty"`
NotSenders []string `json:"not_senders,omitempty"`
NotTypes []string `json:"not_types,omitempty"`
Senders []string `json:"senders,omitempty"`
Types []string `json:"types,omitempty"`
ContainsURL *bool `json:"contains_url,omitempty"`
}
// Validate checks if the filter contains valid property values
func (filter *Filter) Validate() error {
if filter.EventFormat != "client" && filter.EventFormat != "federation" {
return errors.New("Bad event_format value. Must be one of [\"client\", \"federation\"]")
}
return nil
}
// DefaultFilter returns the default filter used by the Matrix server if no filter is provided in the request
func DefaultFilter() Filter {
return Filter{
AccountData: DefaultFilterPart(),
EventFields: nil,
EventFormat: "client",
Presence: DefaultFilterPart(),
Room: RoomFilter{
AccountData: DefaultFilterPart(),
Ephemeral: DefaultFilterPart(),
IncludeLeave: false,
NotRooms: nil,
Rooms: nil,
State: DefaultFilterPart(),
Timeline: DefaultFilterPart(),
},
}
}
// DefaultFilterPart returns the default filter part used by the Matrix server if no filter is provided in the request
func DefaultFilterPart() FilterPart {
return FilterPart{
NotRooms: nil,
Rooms: nil,
Limit: 20,
NotSenders: nil,
NotTypes: nil,
Senders: nil,
Types: nil,
}
}

0
vendor/src/github.com/matrix-org/gomatrixserverlib/hooks/install.sh vendored Normal file → Executable file
View file

2
vendor/src/github.com/matrix-org/gomatrixserverlib/hooks/pre-commit vendored Normal file → Executable file
View file

@ -23,7 +23,7 @@ git checkout-index -a
echo "Installing lint search engine..." echo "Installing lint search engine..."
go get github.com/alecthomas/gometalinter/ go get github.com/alecthomas/gometalinter/
gometalinter --config=linter.json --install --update gometalinter --config=linter.json --install --update --debug
echo "Testing..." echo "Testing..."
go test go test

View file

@ -17,10 +17,10 @@ package gomatrixserverlib
import ( import (
"encoding/binary" "encoding/binary"
"fmt"
"sort" "sort"
"unicode/utf8" "unicode/utf8"
"github.com/pkg/errors"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
@ -29,7 +29,7 @@ import (
// https://matrix.org/docs/spec/server_server/unstable.html#canonical-json // https://matrix.org/docs/spec/server_server/unstable.html#canonical-json
func CanonicalJSON(input []byte) ([]byte, error) { func CanonicalJSON(input []byte) ([]byte, error) {
if !gjson.Valid(string(input)) { if !gjson.Valid(string(input)) {
return nil, errors.Errorf("invalid json") return nil, fmt.Errorf("invalid json")
} }
return CanonicalJSONAssumeValid(input), nil return CanonicalJSONAssumeValid(input), nil
@ -47,8 +47,8 @@ func CanonicalJSONAssumeValid(input []byte) []byte {
func SortJSON(input, output []byte) []byte { func SortJSON(input, output []byte) []byte {
result := gjson.ParseBytes(input) result := gjson.ParseBytes(input)
rawJSON := rawJSONFromResult(result, input) RawJSON := RawJSONFromResult(result, input)
return sortJSONValue(result, rawJSON, output) return sortJSONValue(result, RawJSON, output)
} }
// sortJSONValue takes a gjson.Result and sorts it. inputJSON must be the // sortJSONValue takes a gjson.Result and sorts it. inputJSON must be the
@ -77,8 +77,8 @@ func sortJSONArray(input gjson.Result, inputJSON, output []byte) []byte {
output = append(output, sep) output = append(output, sep)
sep = ',' sep = ','
rawJSON := rawJSONFromResult(value, inputJSON) RawJSON := RawJSONFromResult(value, inputJSON)
output = sortJSONValue(value, rawJSON, output) output = sortJSONValue(value, RawJSON, output)
return true // keep iterating return true // keep iterating
}) })
@ -110,7 +110,7 @@ func sortJSONObject(input gjson.Result, inputJSON, output []byte) []byte {
input.ForEach(func(key, value gjson.Result) bool { input.ForEach(func(key, value gjson.Result) bool {
entries = append(entries, entry{ entries = append(entries, entry{
key: key.String(), key: key.String(),
rawKey: rawJSONFromResult(key, inputJSON), rawKey: RawJSONFromResult(key, inputJSON),
value: value, value: value,
}) })
return true // keep iterating return true // keep iterating
@ -131,9 +131,9 @@ func sortJSONObject(input gjson.Result, inputJSON, output []byte) []byte {
output = append(output, entry.rawKey...) output = append(output, entry.rawKey...)
output = append(output, ':') output = append(output, ':')
rawJSON := rawJSONFromResult(entry.value, inputJSON) RawJSON := RawJSONFromResult(entry.value, inputJSON)
output = sortJSONValue(entry.value, rawJSON, output) output = sortJSONValue(entry.value, RawJSON, output)
} }
if sep == '{' { if sep == '{' {
// If sep is still '{' then the object was empty and we never wrote the // If sep is still '{' then the object was empty and we never wrote the
@ -263,17 +263,17 @@ func readHexDigits(input []byte) uint32 {
return hex & 0xFFFF return hex & 0xFFFF
} }
// rawJSONFromResult extracts the raw JSON bytes pointed to by result. // RawJSONFromResult extracts the raw JSON bytes pointed to by result.
// input must be the json bytes that were used to generate result // input must be the json bytes that were used to generate result
func rawJSONFromResult(result gjson.Result, input []byte) (rawJSON []byte) { func RawJSONFromResult(result gjson.Result, input []byte) (RawJSON []byte) {
// This is lifted from gjson README. Basically, result.Raw is a copy of // This is lifted from gjson README. Basically, result.Raw is a copy of
// the bytes we want, but its more efficient to take a slice. // the bytes we want, but its more efficient to take a slice.
// If Index is 0 then for some reason we can't extract it from the original // If Index is 0 then for some reason we can't extract it from the original
// JSON bytes. // JSON bytes.
if result.Index > 0 { if result.Index > 0 {
rawJSON = input[result.Index : result.Index+len(result.Raw)] RawJSON = input[result.Index : result.Index+len(result.Raw)]
} else { } else {
rawJSON = []byte(result.Raw) RawJSON = []byte(result.Raw)
} }
return return

View file

@ -10,8 +10,8 @@ import (
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
) )
// A PublicKeyRequest is a request for a public key with a particular key ID. // A PublicKeyLookupRequest is a request for a public key with a particular key ID.
type PublicKeyRequest struct { type PublicKeyLookupRequest struct {
// The server to fetch a key for. // The server to fetch a key for.
ServerName ServerName ServerName ServerName
// The ID of the key to fetch. // The ID of the key to fetch.
@ -60,7 +60,7 @@ type KeyFetcher interface {
// The result may have fewer (server name, key ID) pairs than were in the request. // The result may have fewer (server name, key ID) pairs than were in the request.
// The result may have more (server name, key ID) pairs than were in the request. // The result may have more (server name, key ID) pairs than were in the request.
// Returns an error if there was a problem fetching the keys. // Returns an error if there was a problem fetching the keys.
FetchKeys(ctx context.Context, requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]PublicKeyLookupResult, error) FetchKeys(ctx context.Context, requests map[PublicKeyLookupRequest]Timestamp) (map[PublicKeyLookupRequest]PublicKeyLookupResult, error)
// FetcherName returns the name of this fetcher, which can then be used for // FetcherName returns the name of this fetcher, which can then be used for
// logging errors etc. // logging errors etc.
@ -77,7 +77,7 @@ type KeyDatabase interface {
// to a concurrent FetchKeys(). This is acceptable since the database is // to a concurrent FetchKeys(). This is acceptable since the database is
// only used as a cache for the keys, so if a FetchKeys() races with a // only used as a cache for the keys, so if a FetchKeys() races with a
// StoreKeys() and some of the keys are missing they will be just be refetched. // StoreKeys() and some of the keys are missing they will be just be refetched.
StoreKeys(ctx context.Context, results map[PublicKeyRequest]PublicKeyLookupResult) error StoreKeys(ctx context.Context, results map[PublicKeyLookupRequest]PublicKeyLookupResult) error
} }
// A KeyRing stores keys for matrix servers and provides methods for verifying JSON messages. // A KeyRing stores keys for matrix servers and provides methods for verifying JSON messages.
@ -202,15 +202,15 @@ func (k *KeyRing) isAlgorithmSupported(keyID KeyID) bool {
func (k *KeyRing) publicKeyRequests( func (k *KeyRing) publicKeyRequests(
requests []VerifyJSONRequest, results []VerifyJSONResult, keyIDs [][]KeyID, requests []VerifyJSONRequest, results []VerifyJSONResult, keyIDs [][]KeyID,
) map[PublicKeyRequest]Timestamp { ) map[PublicKeyLookupRequest]Timestamp {
keyRequests := map[PublicKeyRequest]Timestamp{} keyRequests := map[PublicKeyLookupRequest]Timestamp{}
for i := range requests { for i := range requests {
if results[i].Error == nil { if results[i].Error == nil {
// We've already verified this message, we don't need to refetch the keys for it. // We've already verified this message, we don't need to refetch the keys for it.
continue continue
} }
for _, keyID := range keyIDs[i] { for _, keyID := range keyIDs[i] {
k := PublicKeyRequest{requests[i].ServerName, keyID} k := PublicKeyLookupRequest{requests[i].ServerName, keyID}
// Grab the maximum neeeded TS for this server and key ID. // Grab the maximum neeeded TS for this server and key ID.
// This will default to 0 if the server and keyID weren't in the map. // This will default to 0 if the server and keyID weren't in the map.
maxTS := keyRequests[k] maxTS := keyRequests[k]
@ -228,7 +228,7 @@ func (k *KeyRing) publicKeyRequests(
func (k *KeyRing) checkUsingKeys( func (k *KeyRing) checkUsingKeys(
requests []VerifyJSONRequest, results []VerifyJSONResult, keyIDs [][]KeyID, requests []VerifyJSONRequest, results []VerifyJSONResult, keyIDs [][]KeyID,
keys map[PublicKeyRequest]PublicKeyLookupResult, keys map[PublicKeyLookupRequest]PublicKeyLookupResult,
) { ) {
for i := range requests { for i := range requests {
if results[i].Error == nil { if results[i].Error == nil {
@ -237,7 +237,7 @@ func (k *KeyRing) checkUsingKeys(
continue continue
} }
for _, keyID := range keyIDs[i] { for _, keyID := range keyIDs[i] {
serverKey, ok := keys[PublicKeyRequest{requests[i].ServerName, keyID}] serverKey, ok := keys[PublicKeyLookupRequest{requests[i].ServerName, keyID}]
if !ok { if !ok {
// No key for this key ID so we continue onto the next key ID. // No key for this key ID so we continue onto the next key ID.
continue continue
@ -282,14 +282,14 @@ func (p PerspectiveKeyFetcher) FetcherName() string {
// FetchKeys implements KeyFetcher // FetchKeys implements KeyFetcher
func (p *PerspectiveKeyFetcher) FetchKeys( func (p *PerspectiveKeyFetcher) FetchKeys(
ctx context.Context, requests map[PublicKeyRequest]Timestamp, ctx context.Context, requests map[PublicKeyLookupRequest]Timestamp,
) (map[PublicKeyRequest]PublicKeyLookupResult, error) { ) (map[PublicKeyLookupRequest]PublicKeyLookupResult, error) {
serverKeys, err := p.Client.LookupServerKeys(ctx, p.PerspectiveServerName, requests) serverKeys, err := p.Client.LookupServerKeys(ctx, p.PerspectiveServerName, requests)
if err != nil { if err != nil {
return nil, err return nil, err
} }
results := map[PublicKeyRequest]PublicKeyLookupResult{} results := map[PublicKeyLookupRequest]PublicKeyLookupResult{}
for _, keys := range serverKeys { for _, keys := range serverKeys {
var valid bool var valid bool
@ -347,19 +347,19 @@ func (d DirectKeyFetcher) FetcherName() string {
// FetchKeys implements KeyFetcher // FetchKeys implements KeyFetcher
func (d *DirectKeyFetcher) FetchKeys( func (d *DirectKeyFetcher) FetchKeys(
ctx context.Context, requests map[PublicKeyRequest]Timestamp, ctx context.Context, requests map[PublicKeyLookupRequest]Timestamp,
) (map[PublicKeyRequest]PublicKeyLookupResult, error) { ) (map[PublicKeyLookupRequest]PublicKeyLookupResult, error) {
byServer := map[ServerName]map[PublicKeyRequest]Timestamp{} byServer := map[ServerName]map[PublicKeyLookupRequest]Timestamp{}
for req, ts := range requests { for req, ts := range requests {
server := byServer[req.ServerName] server := byServer[req.ServerName]
if server == nil { if server == nil {
server = map[PublicKeyRequest]Timestamp{} server = map[PublicKeyLookupRequest]Timestamp{}
byServer[req.ServerName] = server byServer[req.ServerName] = server
} }
server[req] = ts server[req] = ts
} }
results := map[PublicKeyRequest]PublicKeyLookupResult{} results := map[PublicKeyLookupRequest]PublicKeyLookupResult{}
for server := range byServer { for server := range byServer {
// TODO: make these requests in parallel // TODO: make these requests in parallel
serverResults, err := d.fetchKeysForServer(ctx, server) serverResults, err := d.fetchKeysForServer(ctx, server)
@ -376,7 +376,7 @@ func (d *DirectKeyFetcher) FetchKeys(
func (d *DirectKeyFetcher) fetchKeysForServer( func (d *DirectKeyFetcher) fetchKeysForServer(
ctx context.Context, serverName ServerName, ctx context.Context, serverName ServerName,
) (map[PublicKeyRequest]PublicKeyLookupResult, error) { ) (map[PublicKeyLookupRequest]PublicKeyLookupResult, error) {
keys, err := d.Client.GetServerKeys(ctx, serverName) keys, err := d.Client.GetServerKeys(ctx, serverName)
if err != nil { if err != nil {
return nil, err return nil, err
@ -387,7 +387,7 @@ func (d *DirectKeyFetcher) fetchKeysForServer(
return nil, fmt.Errorf("gomatrixserverlib: key response direct from %q failed checks", serverName) return nil, fmt.Errorf("gomatrixserverlib: key response direct from %q failed checks", serverName)
} }
results := map[PublicKeyRequest]PublicKeyLookupResult{} results := map[PublicKeyLookupRequest]PublicKeyLookupResult{}
// TODO (matrix-org/dendrite#345): What happens if the same key ID // TODO (matrix-org/dendrite#345): What happens if the same key ID
// appears in multiple responses? We should probably reject the response. // appears in multiple responses? We should probably reject the response.
@ -397,11 +397,11 @@ func (d *DirectKeyFetcher) fetchKeysForServer(
} }
// mapServerKeysToPublicKeyLookupResult takes the (verified) result from a // mapServerKeysToPublicKeyLookupResult takes the (verified) result from a
// /key/v2/query call and inserts it into a PublicKeyRequest->PublicKeyLookupResult // /key/v2/query call and inserts it into a PublicKeyLookupRequest->PublicKeyLookupResult
// map. // map.
func mapServerKeysToPublicKeyLookupResult(serverKeys ServerKeys, results map[PublicKeyRequest]PublicKeyLookupResult) { func mapServerKeysToPublicKeyLookupResult(serverKeys ServerKeys, results map[PublicKeyLookupRequest]PublicKeyLookupResult) {
for keyID, key := range serverKeys.VerifyKeys { for keyID, key := range serverKeys.VerifyKeys {
results[PublicKeyRequest{ results[PublicKeyLookupRequest{
ServerName: serverKeys.ServerName, ServerName: serverKeys.ServerName,
KeyID: keyID, KeyID: keyID,
}] = PublicKeyLookupResult{ }] = PublicKeyLookupResult{
@ -411,7 +411,7 @@ func mapServerKeysToPublicKeyLookupResult(serverKeys ServerKeys, results map[Pub
} }
} }
for keyID, key := range serverKeys.OldVerifyKeys { for keyID, key := range serverKeys.OldVerifyKeys {
results[PublicKeyRequest{ results[PublicKeyLookupRequest{
ServerName: serverKeys.ServerName, ServerName: serverKeys.ServerName,
KeyID: keyID, KeyID: keyID,
}] = PublicKeyLookupResult{ }] = PublicKeyLookupResult{

View file

@ -41,12 +41,12 @@ func (db testKeyDatabase) FetcherName() string {
} }
func (db *testKeyDatabase) FetchKeys( func (db *testKeyDatabase) FetchKeys(
ctx context.Context, requests map[PublicKeyRequest]Timestamp, ctx context.Context, requests map[PublicKeyLookupRequest]Timestamp,
) (map[PublicKeyRequest]PublicKeyLookupResult, error) { ) (map[PublicKeyLookupRequest]PublicKeyLookupResult, error) {
results := map[PublicKeyRequest]PublicKeyLookupResult{} results := map[PublicKeyLookupRequest]PublicKeyLookupResult{}
req1 := PublicKeyRequest{"localhost:8800", "ed25519:old"} req1 := PublicKeyLookupRequest{"localhost:8800", "ed25519:old"}
req2 := PublicKeyRequest{"localhost:8800", "ed25519:a_Obwu"} req2 := PublicKeyLookupRequest{"localhost:8800", "ed25519:a_Obwu"}
for req := range requests { for req := range requests {
if req == req1 { if req == req1 {
@ -79,7 +79,7 @@ func (db *testKeyDatabase) FetchKeys(
} }
func (db *testKeyDatabase) StoreKeys( func (db *testKeyDatabase) StoreKeys(
ctx context.Context, requests map[PublicKeyRequest]PublicKeyLookupResult, ctx context.Context, requests map[PublicKeyLookupRequest]PublicKeyLookupResult,
) error { ) error {
return nil return nil
} }
@ -161,13 +161,13 @@ func (e erroringKeyDatabase) FetcherName() string {
} }
func (e *erroringKeyDatabase) FetchKeys( func (e *erroringKeyDatabase) FetchKeys(
ctx context.Context, requests map[PublicKeyRequest]Timestamp, ctx context.Context, requests map[PublicKeyLookupRequest]Timestamp,
) (map[PublicKeyRequest]PublicKeyLookupResult, error) { ) (map[PublicKeyLookupRequest]PublicKeyLookupResult, error) {
return nil, &testErrorFetch return nil, &testErrorFetch
} }
func (e *erroringKeyDatabase) StoreKeys( func (e *erroringKeyDatabase) StoreKeys(
ctx context.Context, keys map[PublicKeyRequest]PublicKeyLookupResult, ctx context.Context, keys map[PublicKeyLookupRequest]PublicKeyLookupResult,
) error { ) error {
return &testErrorStore return &testErrorStore
} }

View file

@ -19,16 +19,16 @@ import (
"encoding/json" "encoding/json"
) )
// rawJSON is a reimplementation of json.RawMessage that supports being used as a value type // RawJSON is a reimplementation of json.RawMessage that supports being used as a value type
// //
// For example: // For example:
// //
// jsonBytes, _ := json.Marshal(struct{ // jsonBytes, _ := json.Marshal(struct{
// RawMessage json.RawMessage // RawMessage json.RawMessage
// RawJSON rawJSON // RawJSON RawJSON
// }{ // }{
// json.RawMessage(`"Hello"`), // json.RawMessage(`"Hello"`),
// rawJSON(`"World"`), // RawJSON(`"World"`),
// }) // })
// //
// Results in: // Results in:
@ -36,17 +36,17 @@ import (
// {"RawMessage":"IkhlbGxvIg==","RawJSON":"World"} // {"RawMessage":"IkhlbGxvIg==","RawJSON":"World"}
// //
// See https://play.golang.org/p/FzhKIJP8-I for a full example. // See https://play.golang.org/p/FzhKIJP8-I for a full example.
type rawJSON []byte type RawJSON []byte
// MarshalJSON implements the json.Marshaller interface using a value receiver. // MarshalJSON implements the json.Marshaller interface using a value receiver.
// This means that rawJSON used as an embedded value will still encode correctly. // This means that RawJSON used as an embedded value will still encode correctly.
func (r rawJSON) MarshalJSON() ([]byte, error) { func (r RawJSON) MarshalJSON() ([]byte, error) {
return []byte(r), nil return []byte(r), nil
} }
// UnmarshalJSON implements the json.Unmarshaller interface using a pointer receiver. // UnmarshalJSON implements the json.Unmarshaller interface using a pointer receiver.
func (r *rawJSON) UnmarshalJSON(data []byte) error { func (r *RawJSON) UnmarshalJSON(data []byte) error {
*r = rawJSON(data) *r = RawJSON(data)
return nil return nil
} }
@ -58,45 +58,45 @@ func redactEvent(eventJSON []byte) ([]byte, error) {
// Create events need to keep the creator. // Create events need to keep the creator.
// (In an ideal world they would keep the m.federate flag see matrix-org/synapse#1831) // (In an ideal world they would keep the m.federate flag see matrix-org/synapse#1831)
type createContent struct { type createContent struct {
Creator rawJSON `json:"creator,omitempty"` Creator RawJSON `json:"creator,omitempty"`
} }
// joinRulesContent keeps the fields needed in a m.room.join_rules event. // joinRulesContent keeps the fields needed in a m.room.join_rules event.
// Join rules events need to keep the join_rule key. // Join rules events need to keep the join_rule key.
type joinRulesContent struct { type joinRulesContent struct {
JoinRule rawJSON `json:"join_rule,omitempty"` JoinRule RawJSON `json:"join_rule,omitempty"`
} }
// powerLevelContent keeps the fields needed in a m.room.power_levels event. // powerLevelContent keeps the fields needed in a m.room.power_levels event.
// Power level events need to keep all the levels. // Power level events need to keep all the levels.
type powerLevelContent struct { type powerLevelContent struct {
Users rawJSON `json:"users,omitempty"` Users RawJSON `json:"users,omitempty"`
UsersDefault rawJSON `json:"users_default,omitempty"` UsersDefault RawJSON `json:"users_default,omitempty"`
Events rawJSON `json:"events,omitempty"` Events RawJSON `json:"events,omitempty"`
EventsDefault rawJSON `json:"events_default,omitempty"` EventsDefault RawJSON `json:"events_default,omitempty"`
StateDefault rawJSON `json:"state_default,omitempty"` StateDefault RawJSON `json:"state_default,omitempty"`
Ban rawJSON `json:"ban,omitempty"` Ban RawJSON `json:"ban,omitempty"`
Kick rawJSON `json:"kick,omitempty"` Kick RawJSON `json:"kick,omitempty"`
Redact rawJSON `json:"redact,omitempty"` Redact RawJSON `json:"redact,omitempty"`
} }
// memberContent keeps the fields needed in a m.room.member event. // memberContent keeps the fields needed in a m.room.member event.
// Member events keep the membership. // Member events keep the membership.
// (In an ideal world they would keep the third_party_invite see matrix-org/synapse#1831) // (In an ideal world they would keep the third_party_invite see matrix-org/synapse#1831)
type memberContent struct { type memberContent struct {
Membership rawJSON `json:"membership,omitempty"` Membership RawJSON `json:"membership,omitempty"`
} }
// aliasesContent keeps the fields needed in a m.room.aliases event. // aliasesContent keeps the fields needed in a m.room.aliases event.
// TODO: Alias events probably don't need to keep the aliases key, but we need to match synapse here. // TODO: Alias events probably don't need to keep the aliases key, but we need to match synapse here.
type aliasesContent struct { type aliasesContent struct {
Aliases rawJSON `json:"aliases,omitempty"` Aliases RawJSON `json:"aliases,omitempty"`
} }
// historyVisibilityContent keeps the fields needed in a m.room.history_visibility event // historyVisibilityContent keeps the fields needed in a m.room.history_visibility event
// History visibility events need to keep the history_visibility key. // History visibility events need to keep the history_visibility key.
type historyVisibilityContent struct { type historyVisibilityContent struct {
HistoryVisibility rawJSON `json:"history_visibility,omitempty"` HistoryVisibility RawJSON `json:"history_visibility,omitempty"`
} }
// allContent keeps the union of all the content fields needed across all the event types. // allContent keeps the union of all the content fields needed across all the event types.
@ -114,21 +114,21 @@ func redactEvent(eventJSON []byte) ([]byte, error) {
// (In an ideal world they would include the "redacts" key for m.room.redaction events, see matrix-org/synapse#1831) // (In an ideal world they would include the "redacts" key for m.room.redaction events, see matrix-org/synapse#1831)
// See https://github.com/matrix-org/synapse/blob/v0.18.7/synapse/events/utils.py#L42-L56 for the list of fields // See https://github.com/matrix-org/synapse/blob/v0.18.7/synapse/events/utils.py#L42-L56 for the list of fields
type eventFields struct { type eventFields struct {
EventID rawJSON `json:"event_id,omitempty"` EventID RawJSON `json:"event_id,omitempty"`
Sender rawJSON `json:"sender,omitempty"` Sender RawJSON `json:"sender,omitempty"`
RoomID rawJSON `json:"room_id,omitempty"` RoomID RawJSON `json:"room_id,omitempty"`
Hashes rawJSON `json:"hashes,omitempty"` Hashes RawJSON `json:"hashes,omitempty"`
Signatures rawJSON `json:"signatures,omitempty"` Signatures RawJSON `json:"signatures,omitempty"`
Content allContent `json:"content"` Content allContent `json:"content"`
Type string `json:"type"` Type string `json:"type"`
StateKey rawJSON `json:"state_key,omitempty"` StateKey RawJSON `json:"state_key,omitempty"`
Depth rawJSON `json:"depth,omitempty"` Depth RawJSON `json:"depth,omitempty"`
PrevEvents rawJSON `json:"prev_events,omitempty"` PrevEvents RawJSON `json:"prev_events,omitempty"`
PrevState rawJSON `json:"prev_state,omitempty"` PrevState RawJSON `json:"prev_state,omitempty"`
AuthEvents rawJSON `json:"auth_events,omitempty"` AuthEvents RawJSON `json:"auth_events,omitempty"`
Origin rawJSON `json:"origin,omitempty"` Origin RawJSON `json:"origin,omitempty"`
OriginServerTS rawJSON `json:"origin_server_ts,omitempty"` OriginServerTS RawJSON `json:"origin_server_ts,omitempty"`
Membership rawJSON `json:"membership,omitempty"` Membership RawJSON `json:"membership,omitempty"`
} }
var event eventFields var event eventFields

View file

@ -21,7 +21,7 @@ type FederationRequest struct {
// fields implement the JSON format needed for signing // fields implement the JSON format needed for signing
// specified in https://matrix.org/docs/spec/server_server/unstable.html#request-authentication // specified in https://matrix.org/docs/spec/server_server/unstable.html#request-authentication
fields struct { fields struct {
Content rawJSON `json:"content,omitempty"` Content RawJSON `json:"content,omitempty"`
Destination ServerName `json:"destination"` Destination ServerName `json:"destination"`
Method string `json:"method"` Method string `json:"method"`
Origin ServerName `json:"origin"` Origin ServerName `json:"origin"`
@ -56,7 +56,7 @@ func (r *FederationRequest) SetContent(content interface{}) error {
if err != nil { if err != nil {
return err return err
} }
r.fields.Content = rawJSON(data) r.fields.Content = RawJSON(data)
return nil return nil
} }
@ -252,7 +252,7 @@ func readHTTPRequest(req *http.Request) (*FederationRequest, error) { // nolint:
req.Header.Get("Content-Type"), req.Header.Get("Content-Type"),
) )
} }
result.fields.Content = rawJSON(content) result.fields.Content = RawJSON(content)
} }
for _, authorization := range req.Header["Authorization"] { for _, authorization := range req.Header["Authorization"] {

View file

@ -104,7 +104,7 @@ func TestSignPutRequest(t *testing.T) {
request := NewFederationRequest( request := NewFederationRequest(
"PUT", "localhost:44033", "/_matrix/federation/v1/send/1493385816575/", "PUT", "localhost:44033", "/_matrix/federation/v1/send/1493385816575/",
) )
if err := request.SetContent(rawJSON([]byte(examplePutContent))); err != nil { if err := request.SetContent(RawJSON([]byte(examplePutContent))); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := request.Sign("localhost:8800", "ed25519:a_Obwu", privateKey1); err != nil { if err := request.Sign("localhost:8800", "ed25519:a_Obwu", privateKey1); err != nil {

View file

@ -0,0 +1,132 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tokens
import (
"encoding/base64"
"errors"
"fmt"
"strconv"
"time"
macaroon "gopkg.in/macaroon.v2"
)
const (
macaroonVersion = macaroon.V2
defaultDuration = 2 * 60
// UserPrefix is a common prefix for every user_id caveat
UserPrefix = "user_id = "
// TimePrefix is a common prefix for every expiry caveat
TimePrefix = "time < "
// Gen is a common caveat for every token
Gen = "gen = 1"
)
// TokenOptions represent parameters of Token
type TokenOptions struct {
ServerPrivateKey []byte `yaml:"private_key"`
ServerName string `yaml:"server_name"`
UserID string `json:"user_id"`
Duration int // optional
}
// GenerateLoginToken generates a short term login token to be used as
// token authentication ("m.login.token")
func GenerateLoginToken(op TokenOptions) (string, error) {
if !isValidTokenOptions(op) {
return "", errors.New("The given TokenOptions is invalid")
}
mac, err := generateBaseMacaroon(op.ServerPrivateKey, op.ServerName, op.UserID)
if err != nil {
return "", err
}
if op.Duration == 0 {
op.Duration = defaultDuration
}
now := time.Now().Second()
expiryCaveat := TimePrefix + strconv.Itoa(now+op.Duration)
err = mac.AddFirstPartyCaveat([]byte(expiryCaveat))
if err != nil {
return "", macaroonError(err)
}
urlSafeEncode, err := serializeMacaroon(*mac)
if err != nil {
return "", macaroonError(err)
}
return urlSafeEncode, nil
}
// isValidTokenOptions checks for required fields in a TokenOptions
func isValidTokenOptions(op TokenOptions) bool {
if op.ServerPrivateKey == nil || op.ServerName == "" || op.UserID == "" {
return false
}
return true
}
// generateBaseMacaroon generates a base macaroon common for accessToken & loginToken.
// Returns a macaroon tied with userID,
// returns an error if something goes wrong.
func generateBaseMacaroon(
secret []byte, ServerName string, userID string,
) (*macaroon.Macaroon, error) {
mac, err := macaroon.New(secret, []byte(userID), ServerName, macaroonVersion)
if err != nil {
return nil, macaroonError(err)
}
err = mac.AddFirstPartyCaveat([]byte(Gen))
if err != nil {
return nil, macaroonError(err)
}
err = mac.AddFirstPartyCaveat([]byte(UserPrefix + userID))
if err != nil {
return nil, macaroonError(err)
}
return mac, nil
}
func macaroonError(err error) error {
return fmt.Errorf("Macaroon creation failed: %s", err.Error())
}
// serializeMacaroon takes a macaroon to be serialized.
// returns its base64 encoded string, URL safe, which can be sent via web, email, etc.
func serializeMacaroon(m macaroon.Macaroon) (string, error) {
bin, err := m.MarshalBinary()
if err != nil {
return "", err
}
urlSafeEncode := base64.RawURLEncoding.EncodeToString(bin)
return urlSafeEncode, nil
}
// deSerializeMacaroon takes a base64 encoded string of a macaroon to be de-serialized.
// Returns a macaroon. On failure returns error with description.
func deSerializeMacaroon(urlSafeEncode string) (macaroon.Macaroon, error) {
var mac macaroon.Macaroon
bin, err := base64.RawURLEncoding.DecodeString(urlSafeEncode)
if err != nil {
return mac, err
}
err = mac.UnmarshalBinary(bin)
return mac, err
}

View file

@ -0,0 +1,102 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tokens
import (
"errors"
"strconv"
"strings"
"time"
)
// GetUserFromToken returns the user associated with the token
// Returns the error if something goes wrong.
// Warning: Does not validate the token. Use ValidateToken for that.
func GetUserFromToken(token string) (user string, err error) {
mac, err := deSerializeMacaroon(token)
if err != nil {
return
}
user = string(mac.Id()[:])
return
}
// ValidateToken validates that the Token is understood and was signed by this server.
// Returns nil if token is valid, otherwise returns a error.
func ValidateToken(op TokenOptions, token string) error {
mac, err := deSerializeMacaroon(token)
if err != nil {
return errors.New("Token does not represent a valid macaroon")
}
caveats, err := mac.VerifySignature(op.ServerPrivateKey, nil)
if err != nil {
return errors.New("Provided token was not issued by this server")
}
err = verifyCaveats(caveats, op.UserID)
if err != nil {
return errors.New("Provided token not authorized")
}
return nil
}
// verifyCaveats verifies caveats associated with a login token macaroon.
// which are "gen = 1", "user_id = ...", "time < ..."
// Returns nil on successful verification, else returns an error.
func verifyCaveats(caveats []string, userID string) error {
// variable verified represents a bitmap
// last 4 bits are Uvvv where,
// U: unknownCaveat
// v: caveat to be verified
var verified uint8
now := time.Now().Second()
LoopCaveat:
for _, caveat := range caveats {
switch {
case caveat == Gen:
verified |= 1
case strings.HasPrefix(caveat, UserPrefix):
if caveat[len(UserPrefix):] == userID {
verified |= 2
}
case strings.HasPrefix(caveat, TimePrefix):
if verifyExpiry(caveat[len(TimePrefix):], now) {
verified |= 4
}
default:
verified |= 8
break LoopCaveat
}
}
// Check that all three caveats are verified and no extra caveats
// i.e. Uvvv == 0111
if verified == 7 {
return nil
} else if verified >= 8 {
return errors.New("Unknown caveat present")
}
return errors.New("Required caveats not present")
}
func verifyExpiry(t string, now int) bool {
expiry, err := strconv.Atoi(t)
if err != nil {
return false
}
return now < expiry
}

View file

@ -0,0 +1,92 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tokens
import (
"testing"
"time"
)
var (
// If any of these options are missing, validation should fail
invalidMissings = []string{"ServerPrivateKey", "UserID"}
invalidKeyTokenOp = TokenOptions{
ServerPrivateKey: []byte("notASecretKey"),
UserID: "aRandomUserID",
}
invalidUserTokenOp = TokenOptions{
ServerPrivateKey: []byte("aSecretKey"),
UserID: "notTheSameUserID",
}
)
func expireZeroValidTokenOp() TokenOptions {
op := validTokenOp
op.Duration = 0
return op
}
func TestExpiredLoginToken(t *testing.T) {
fakeToken, err := GenerateLoginToken(expireZeroValidTokenOp())
// token uses 1 second precision
time.Sleep(time.Second)
res := ValidateToken(validTokenOp, fakeToken)
if res == nil {
t.Error("Token validation should fail for expired token")
}
}
func TestValidateToken(t *testing.T) {
fakeToken, err := GenerateLoginToken(validTokenOp)
if err != nil {
t.Errorf("Token generation failed for valid TokenOptions with err: %s", err.Error())
}
// Test validation
res := ValidateToken(validTokenOp, fakeToken)
if res != nil {
t.Error("Token validation failed with response: ", res)
}
// Test validation fails for invalid TokenOp
for _, invalidMissing := range invalidMissings {
res = ValidateToken(invalidTokenOps[invalidMissing], fakeToken)
if res == nil {
t.Errorf("Token validation should fail for TokenOptions with missing %s", invalidMissing)
}
}
for _, invalid := range []TokenOptions{invalidKeyTokenOp, invalidUserTokenOp} {
res = ValidateToken(invalid, fakeToken)
if res == nil {
t.Errorf("Token validation should fail for invalid TokenOptions: ", invalid)
}
}
}
func TestGetUserFromToken(t *testing.T) {
fakeToken, err := GenerateLoginToken(validTokenOp)
if err != nil {
t.Errorf("Token generation failed for valid TokenOptions with err: %s", err.Error())
}
// Test validation
name, err := GetUserFromToken(fakeToken)
if err != nil {
t.Error("Failed to get userID from Token: ", err)
}
if name != validTokenOp.UserID {
t.Error("UserID from Token doesn't match, got: ", name, " expected: ", validTokenOp.UserID)
}
}

View file

@ -0,0 +1,80 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tokens
import (
"testing"
)
var (
validTokenOp = TokenOptions{
ServerPrivateKey: []byte("aSecretKey"),
ServerName: "aRandomServerName",
UserID: "aRandomUserID",
}
invalidTokenOps = map[string]TokenOptions{
"ServerPrivateKey": {
ServerName: "aRandomServerName",
UserID: "aRandomUserID",
},
"ServerName": {
ServerPrivateKey: []byte("aSecretKey"),
UserID: "aRandomUserID",
},
"UserID": {
ServerPrivateKey: []byte("aSecretKey"),
ServerName: "aRandomServerName",
},
}
)
func TestGenerateLoginToken(t *testing.T) {
// Test valid
_, err := GenerateLoginToken(validTokenOp)
if err != nil {
t.Errorf("Token generation failed for valid TokenOptions with err: %s", err.Error())
}
// Test invalids
for missing, invalidTokenOp := range invalidTokenOps {
_, err := GenerateLoginToken(invalidTokenOp)
if err == nil {
t.Errorf("Token generation should fail for TokenOptions with missing %s", missing)
}
}
}
func serializationTestError(err error) string {
return "Token Serialization test failed with err: " + err.Error()
}
func TestSerialization(t *testing.T) {
fakeToken, err := GenerateLoginToken(validTokenOp)
if err != nil {
t.Errorf(serializationTestError(err))
}
fakeMacaroon, err := deSerializeMacaroon(fakeToken)
if err != nil {
t.Errorf(serializationTestError(err))
}
sameFakeToken, err := serializeMacaroon(fakeMacaroon)
if err != nil {
t.Errorf(serializationTestError(err))
}
if sameFakeToken != fakeToken {
t.Errorf("Token Serialization mismatch")
}
}

3
vendor/src/github.com/matrix-org/gomatrixserverlib/travis.sh vendored Normal file → Executable file
View file

@ -16,5 +16,6 @@ go get -u \
github.com/tidwall/sjson \ github.com/tidwall/sjson \
github.com/pkg/errors \ github.com/pkg/errors \
gopkg.in/yaml.v2 \ gopkg.in/yaml.v2 \
gopkg.in/macaroon.v2 \
./hooks/pre-commit ./hooks/pre-commit