Merge branch 'master' of https://github.com/matrix-org/dendrite into recaptcha

This commit is contained in:
Andrew Morgan (https://amorgan.xyz) 2017-11-29 23:50:29 -08:00
commit 11ff9c3e2e
No known key found for this signature in database
GPG key ID: 174BEAB009FD176D
43 changed files with 7974 additions and 251 deletions

View file

@ -2,4 +2,21 @@
set -eu
# make the GIT_DIR and GIT_INDEX_FILE absolute, before we change dir
export GIT_DIR=$(readlink -f `git rev-parse --git-dir`)
if [ -n "${GIT_INDEX_FILE:+x}" ]; then
export GIT_INDEX_FILE=$(readlink -f "$GIT_INDEX_FILE")
fi
# create a temp dir. The `trap` incantation will ensure that it is removed
# again when this script completes.
tmpdir=`mktemp -d`
trap 'rm -rf "$tmpdir"' EXIT
cd "$tmpdir"
# get a clean copy of the index (ie, what has been `git add`ed), so that we can
# run the checks against what we are about to commit, rather than what is in
# the working copy.
git checkout-index -a
./scripts/find-lint.sh fast

View file

@ -1,4 +1,4 @@
// Copyright 2017 Vector Creations Ltd
// Copyright Andrew Morgan <andrew@amorgan.xyz>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -15,7 +15,7 @@
package authtypes
// Flow represents one possible way that the client can authenticate a request.
// http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#user-interactive-authentication-api
// https://matrix.org/docs/spec/client_server/r0.3.0.html#user-interactive-authentication-api
type Flow struct {
Stages []LoginType `json:"stages"`
}

View file

@ -23,10 +23,8 @@ import (
"errors"
"fmt"
"io/ioutil"
"math/rand"
"net/http"
"net/url"
"reflect"
"regexp"
"sort"
"strings"
@ -53,6 +51,7 @@ const (
)
var (
// TODO: Remove old sessions. Need to do so on a session-specific timeout.
sessions = make(map[string][]authtypes.LoginType) // Sessions and completed flow stages
validUsernameRegex = regexp.MustCompile(`^[0-9a-zA-Z_\-./]+$`)
)
@ -79,7 +78,7 @@ type authDict struct {
Session string `json:"session"`
Mac gomatrixserverlib.HexString `json:"mac"`
// ReCaptcha
// Recaptcha
Response string `json:"response"`
// TODO: Lots of custom keys depending on the type
}
@ -101,6 +100,8 @@ type legacyRegisterRequest struct {
Mac gomatrixserverlib.HexString `json:"mac"`
}
// newUserInteractiveResponse will return a struct to be sent back to the client
// during registration.
func newUserInteractiveResponse(
sessionID string,
fs []authtypes.Flow,
@ -119,7 +120,7 @@ type registerResponse struct {
DeviceID string `json:"device_id"`
}
// recaptchaResponse represents the HTTP response from a Google ReCaptcha server
// recaptchaResponse represents the HTTP response from a Google Recaptcha server
type recaptchaResponse struct {
Success bool `json:"success"`
ChallengeTS time.Time `json:"challenge_ts"`
@ -225,9 +226,6 @@ func validateRecaptcha(
return nil
}
// TODO: Create flows in config.go so that they're cached. Always show msisdn flows as long as flows only depend on config-file options.
// Store it just like the config does in a struct and keep it there.
// Register processes a /register request. http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register
func Register(
req *http.Request,
@ -246,7 +244,7 @@ func Register(
sessionID := r.Auth.Session
if sessionID == "" {
// Generate a new, random session ID
sessionID = RandString(sessionIDLength)
sessionID = util.RandomString(sessionIDLength)
}
// If no auth type is specified by the client, send back the list of available flows
@ -254,7 +252,7 @@ func Register(
return util.JSONResponse{
Code: 401,
JSON: newUserInteractiveResponse(sessionID,
cfg.Derived.Flows, cfg.Derived.Params),
cfg.Derived.Registration.Flows, cfg.Derived.Registration.Params),
}
}
@ -346,21 +344,19 @@ func handleRegistrationFlow(
}
}
// Check if a registration flow has been completed successfully
for _, flow := range cfg.Derived.Flows {
if checkFlowsEqual(flow, authtypes.Flow{sessions[sessionID]}) {
return completeRegistration(req.Context(), accountDB, deviceDB,
r.Username, r.Password, r.InitialDisplayName)
}
}
// Check if the user's registration flow has been completed successfully
if !checkFlowCompleted(sessions[sessionID], cfg.Derived.Registration.Flows) {
// There are still more stages to complete.
// Return the flows and those that have been completed.
return util.JSONResponse{
Code: 401,
JSON: newUserInteractiveResponse(sessionID,
cfg.Derived.Flows, cfg.Derived.Params),
cfg.Derived.Registration.Flows, cfg.Derived.Registration.Params),
}
}
return completeRegistration(req.Context(), accountDB, deviceDB,
r.Username, r.Password, r.InitialDisplayName)
}
// LegacyRegister process register requests from the legacy v1 API
@ -487,7 +483,7 @@ func isValidMacLogin(
givenMac []byte,
sharedSecret string,
) (bool, error) {
// Double check that username/passowrd don't contain the HMAC delimiters. We should have
// Double check that username/password don't contain the HMAC delimiters. We should have
// already checked this.
if strings.Contains(username, "\x00") {
return false, errors.New("Username contains invalid character")
@ -515,67 +511,60 @@ func isValidMacLogin(
return hmac.Equal(givenMac, expectedMAC), nil
}
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
const (
letterIdxBits = 6 // 6 bits to represent a letter index
letterIdxMask = 1<<letterIdxBits - 1 // All 1-bits, as many as letterIdxBits
letterIdxMax = 63 / letterIdxBits // # of letter indices fitting in 63 bits
)
// checkFlows checks a single completed flow against another required one. If
// one contains at least all of the stages that the other does, checkFlows
// returns true.
func checkFlows(
completedStages []authtypes.LoginType,
requiredStages []authtypes.LoginType,
) bool {
// Create temporary slices so they originals will not be modified on sorting
completed := make([]authtypes.LoginType, len(completedStages))
required := make([]authtypes.LoginType, len(requiredStages))
copy(completed, completedStages)
copy(required, requiredStages)
var src = rand.NewSource(time.Now().UnixNano())
// Sort the slices for simple comparison
sort.Slice(completed, func(i, j int) bool { return completed[i] < completed[j] })
sort.Slice(required, func(i, j int) bool { return required[i] < required[j] })
// RandString returns a random string of characters with a given length.
// Do note that it is not thread-safe in its current form.
// https://stackoverflow.com/a/31832326
func RandString(n int) string {
b := make([]byte, n)
// A src.Int63() generates 63 random bits
for i, cache, remain := n-1, src.Int63(), letterIdxMax; i >= 0; {
if remain == 0 {
cache, remain = src.Int63(), letterIdxMax
}
if idx := int(cache & letterIdxMask); idx < len(letterBytes) {
b[i] = letterBytes[idx]
i--
}
cache >>= letterIdxBits
remain--
}
return string(b)
}
// checkFlowsEqual checks if two registration flows have the same stages
// within them. Order of stages does not matter.
func checkFlowsEqual(aFlow, bFlow authtypes.Flow) bool {
a := aFlow.Stages
b := bFlow.Stages
if len(a) != len(b) {
// Iterate through each slice, going to the next required slice only once
// we've found a match.
i, j := 0, 0
for j < len(required) {
// Exit if we've reached the end of our input without being able to
// match all of the required stages.
if i >= len(completed) {
return false
}
aCopy := make([]string, len(a))
bCopy := make([]string, len(b))
for loginType := range a {
aCopy = append(aCopy, string(loginType))
// If we've found a stage we want, move on to the next required stage.
if completed[i] == required[j] {
j++
}
for loginType := range b {
bCopy = append(bCopy, string(loginType))
i++
}
return true
}
sort.Strings(aCopy)
sort.Strings(bCopy)
return reflect.DeepEqual(aCopy, bCopy)
// checkFlowCompleted checks if a registration flow complies with any allowed flow
// dictated by the server. Order of stages does not matter. A user may complete
// extra stages as long as the required stages of at least one flow is met.
func checkFlowCompleted(flow []authtypes.LoginType, allowedFlows []authtypes.Flow) bool {
// Iterate through possible flows to check whether any have been fully completed.
for _, allowedFlow := range allowedFlows {
if checkFlows(flow, allowedFlow.Stages) {
return true
}
}
return false
}
type availableResponse struct {
Available bool `json:"available"`
}
// RegisterAvailable checks if the username is already taken or invalid
// RegisterAvailable checks if the username is already taken or invalid.
func RegisterAvailable(
req *http.Request,
accountDB *accounts.Database,

View file

@ -0,0 +1,134 @@
// Copyright 2017 Andrew Morgan <andrew@amorgan.xyz>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"testing"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
var (
// Registration Flows that the server allows.
allowedFlows []authtypes.Flow = []authtypes.Flow{
{
[]authtypes.LoginType{
authtypes.LoginType("stage1"),
authtypes.LoginType("stage2"),
},
},
{
[]authtypes.LoginType{
authtypes.LoginType("stage1"),
authtypes.LoginType("stage3"),
},
},
}
)
// Should return true as we're completing all the stages of a single flow in
// order.
func TestFlowCheckingCompleteFlowOrdered(t *testing.T) {
testFlow := []authtypes.LoginType{
authtypes.LoginType("stage1"),
authtypes.LoginType("stage3"),
}
if !checkFlowCompleted(testFlow, allowedFlows) {
t.Error("Incorrect registration flow verification: ", testFlow, ", from allowed flows: ", allowedFlows, ". Should be true.")
}
}
// Should return false as all stages in a single flow need to be completed.
func TestFlowCheckingStagesFromDifferentFlows(t *testing.T) {
testFlow := []authtypes.LoginType{
authtypes.LoginType("stage2"),
authtypes.LoginType("stage3"),
}
if checkFlowCompleted(testFlow, allowedFlows) {
t.Error("Incorrect registration flow verification: ", testFlow, ", from allowed flows: ", allowedFlows, ". Should be false.")
}
}
// Should return true as we're completing all the stages from a single flow, as
// well as some extraneous stages.
func TestFlowCheckingCompleteOrderedExtraneous(t *testing.T) {
testFlow := []authtypes.LoginType{
authtypes.LoginType("stage1"),
authtypes.LoginType("stage3"),
authtypes.LoginType("stage4"),
authtypes.LoginType("stage5"),
}
if !checkFlowCompleted(testFlow, allowedFlows) {
t.Error("Incorrect registration flow verification: ", testFlow, ", from allowed flows: ", allowedFlows, ". Should be true.")
}
}
// Should return false as we're submitting an empty flow.
func TestFlowCheckingEmptyFlow(t *testing.T) {
testFlow := []authtypes.LoginType{}
if checkFlowCompleted(testFlow, allowedFlows) {
t.Error("Incorrect registration flow verification: ", testFlow, ", from allowed flows: ", allowedFlows, ". Should be false.")
}
}
// Should return false as we've completed a stage that isn't in any allowed flow.
func TestFlowCheckingInvalidStage(t *testing.T) {
testFlow := []authtypes.LoginType{
authtypes.LoginType("stage8"),
}
if checkFlowCompleted(testFlow, allowedFlows) {
t.Error("Incorrect registration flow verification: ", testFlow, ", from allowed flows: ", allowedFlows, ". Should be false.")
}
}
// Should return true as we complete all stages of an allowed flow, though out
// of order, as well as extraneous stages.
func TestFlowCheckingExtraneousUnordered(t *testing.T) {
testFlow := []authtypes.LoginType{
authtypes.LoginType("stage5"),
authtypes.LoginType("stage4"),
authtypes.LoginType("stage3"),
authtypes.LoginType("stage2"),
authtypes.LoginType("stage1"),
}
if !checkFlowCompleted(testFlow, allowedFlows) {
t.Error("Incorrect registration flow verification: ", testFlow, ", from allowed flows: ", allowedFlows, ". Should be true.")
}
}
// Should return false as we're providing fewer stages than are required.
func TestFlowCheckingShortIncorrectInput(t *testing.T) {
testFlow := []authtypes.LoginType{
authtypes.LoginType("stage8"),
}
if checkFlowCompleted(testFlow, allowedFlows) {
t.Error("Incorrect registration flow verification: ", testFlow, ", from allowed flows: ", allowedFlows, ". Should be false.")
}
}
// Should return false as we're providing different stages than are required.
func TestFlowCheckingExtraneousIncorrectInput(t *testing.T) {
testFlow := []authtypes.LoginType{
authtypes.LoginType("stage8"),
authtypes.LoginType("stage9"),
authtypes.LoginType("stage10"),
authtypes.LoginType("stage11"),
}
if checkFlowCompleted(testFlow, allowedFlows) {
t.Error("Incorrect registration flow verification: ", testFlow, ", from allowed flows: ", allowedFlows, ". Should be false.")
}
}

View file

@ -63,7 +63,7 @@ func Setup(
}},
}
}),
).Methods("GET")
).Methods("GET", "OPTIONS")
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
v1mux := apiMux.PathPrefix(pathPrefixV1).Subrouter()
@ -131,14 +131,14 @@ func Setup(
r0mux.Handle("/register/available", common.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse {
return RegisterAvailable(req, accountDB)
})).Methods("GET")
})).Methods("GET", "OPTIONS")
r0mux.Handle("/directory/room/{roomAlias}",
common.MakeAuthAPI("directory_room", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return DirectoryRoom(req, vars["roomAlias"], federation, &cfg, aliasAPI)
}),
).Methods("GET")
).Methods("GET", "OPTIONS")
r0mux.Handle("/directory/room/{roomAlias}",
common.MakeAuthAPI("directory_room", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
@ -152,7 +152,7 @@ func Setup(
vars := mux.Vars(req)
return RemoveLocalAlias(req, device, vars["roomAlias"], aliasAPI)
}),
).Methods("DELETE")
).Methods("DELETE", "OPTIONS")
r0mux.Handle("/logout",
common.MakeAuthAPI("logout", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
@ -191,7 +191,7 @@ func Setup(
JSON: &res,
}
}),
).Methods("GET")
).Methods("GET", "OPTIONS")
r0mux.Handle("/user/{userId}/filter",
common.MakeAuthAPI("put_filter", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
@ -205,7 +205,7 @@ func Setup(
vars := mux.Vars(req)
return GetFilter(req, device, accountDB, vars["userId"], vars["filterId"])
}),
).Methods("GET")
).Methods("GET", "OPTIONS")
// Riot user settings
@ -214,14 +214,14 @@ func Setup(
vars := mux.Vars(req)
return GetProfile(req, accountDB, vars["userID"])
}),
).Methods("GET")
).Methods("GET", "OPTIONS")
r0mux.Handle("/profile/{userID}/avatar_url",
common.MakeExternalAPI("profile_avatar_url", func(req *http.Request) util.JSONResponse {
vars := mux.Vars(req)
return GetAvatarURL(req, accountDB, vars["userID"])
}),
).Methods("GET")
).Methods("GET", "OPTIONS")
r0mux.Handle("/profile/{userID}/avatar_url",
common.MakeAuthAPI("profile_avatar_url", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
@ -237,7 +237,7 @@ func Setup(
vars := mux.Vars(req)
return GetDisplayName(req, accountDB, vars["userID"])
}),
).Methods("GET")
).Methods("GET", "OPTIONS")
r0mux.Handle("/profile/{userID}/displayname",
common.MakeAuthAPI("profile_displayname", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
@ -252,7 +252,7 @@ func Setup(
common.MakeAuthAPI("account_3pid", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return GetAssociated3PIDs(req, accountDB, device)
}),
).Methods("GET")
).Methods("GET", "OPTIONS")
r0mux.Handle("/account/3pid",
common.MakeAuthAPI("account_3pid", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
@ -287,7 +287,7 @@ func Setup(
common.MakeAuthAPI("turn_server", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return RequestTurnServer(req, device, cfg)
}),
).Methods("GET")
).Methods("GET", "OPTIONS")
unstableMux.Handle("/thirdparty/protocols",
common.MakeExternalAPI("thirdparty_protocols", func(req *http.Request) util.JSONResponse {
@ -297,7 +297,7 @@ func Setup(
JSON: struct{}{},
}
}),
).Methods("GET")
).Methods("GET", "OPTIONS")
r0mux.Handle("/rooms/{roomID}/initialSync",
common.MakeExternalAPI("rooms_initial_sync", func(req *http.Request) util.JSONResponse {
@ -307,7 +307,7 @@ func Setup(
JSON: jsonerror.GuestAccessForbidden("Guest access not implemented"),
}
}),
).Methods("GET")
).Methods("GET", "OPTIONS")
r0mux.Handle("/user/{userID}/account_data/{type}",
common.MakeAuthAPI("user_account_data", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
@ -328,14 +328,14 @@ func Setup(
vars := mux.Vars(req)
return GetMemberships(req, device, vars["roomID"], false, cfg, queryAPI)
}),
).Methods("GET")
).Methods("GET", "OPTIONS")
r0mux.Handle("/rooms/{roomID}/joined_members",
common.MakeAuthAPI("rooms_members", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return GetMemberships(req, device, vars["roomID"], true, cfg, queryAPI)
}),
).Methods("GET")
).Methods("GET", "OPTIONS")
r0mux.Handle("/rooms/{roomID}/read_markers",
common.MakeExternalAPI("rooms_read_markers", func(req *http.Request) util.JSONResponse {
@ -355,14 +355,14 @@ func Setup(
common.MakeAuthAPI("get_devices", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return GetDevicesByLocalpart(req, deviceDB, device)
}),
).Methods("GET")
).Methods("GET", "OPTIONS")
r0mux.Handle("/device/{deviceID}",
common.MakeAuthAPI("get_device", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return GetDeviceByID(req, deviceDB, device, vars["deviceID"])
}),
).Methods("GET")
).Methods("GET", "OPTIONS")
r0mux.Handle("/devices/{deviceID}",
common.MakeAuthAPI("device_data", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
@ -380,7 +380,7 @@ func Setup(
"end": "",
}}
}),
).Methods("GET")
).Methods("GET", "OPTIONS")
r0mux.Handle("/initialSync",
common.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse {
@ -388,5 +388,5 @@ func Setup(
"end": "",
}}
}),
).Methods("GET")
).Methods("GET", "OPTIONS")
}

View file

@ -239,7 +239,7 @@ func testDownload(host, origin, mediaID string, wantedStatusCode int, serverCmdC
testReq := &test.Request{
Req: req,
WantedStatusCode: wantedStatusCode,
WantedBody: test.CanonicalJSONInput([]string{""})[0],
WantedBody: "",
}
testReq.Run(fmt.Sprintf("download mxc://%v/%v from %v", origin, mediaID, host), timeout, serverCmdChan)
}
@ -263,7 +263,7 @@ func testThumbnail(width, height int, resizeMethod, host string, serverCmdChan c
testReq := &test.Request{
Req: req,
WantedStatusCode: 200,
WantedBody: test.CanonicalJSONInput([]string{""})[0],
WantedBody: "",
}
testReq.Run(fmt.Sprintf("thumbnail mxc://%v/%v%v from %v", testOrigin, testMediaID, query, host), timeout, serverCmdChan)
}

View file

@ -205,14 +205,18 @@ type Dendrite struct {
// Any information derived from the configuration options for later use.
Derived struct {
// Flows for registration. As long as they given flows only relies on config file options,
Registration struct {
// Flows is a slice of flows, which represent one possible way that the client can authenticate a request.
// http://matrix.org/docs/spec/HEAD/client_server/r0.3.0.html#user-interactive-authentication-api
// As long as the generated flows only rely on config file options,
// we can generate them on startup and store them until needed
Flows []authtypes.Flow `json:"flows"`
// Params for registration. Data that is returned to the client while registering and
// that which is necessary to complete certain registration flow stages
// Params that need to be returned to the client during
// registration in order to complete registration stages.
Params map[string]interface{} `json:"params"`
}
}
}
// A Path on the filesystem.
@ -336,26 +340,26 @@ func loadConfig(
}
// derive generates data that is derived from various values provided in
// the config file
// the config file.
func (config *Dendrite) derive() {
// Determine registrations flows based off config values
config.Derived.Params = make(map[string]interface{})
config.Derived.Registration.Params = make(map[string]interface{})
// TODO: Add email auth type
// TODO: Add MSISDN auth type
if config.Matrix.RecaptchaEnabled {
config.Derived.Params[authtypes.LoginTypeRecaptcha] = map[string]string{"public_key": config.Matrix.RecaptchaPublicKey}
config.Derived.Flows = append(config.Derived.Flows,
config.Derived.Registration.Params[authtypes.LoginTypeRecaptcha] = map[string]string{"public_key": config.Matrix.RecaptchaPublicKey}
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows,
authtypes.Flow{[]authtypes.LoginType{authtypes.LoginTypeRecaptcha}})
} else {
config.Derived.Flows = append(config.Derived.Flows,
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows,
authtypes.Flow{[]authtypes.LoginType{authtypes.LoginTypeDummy}})
}
}
// setDefaults sets default config values if they are not explicitly set
// setDefaults sets default config values if they are not explicitly set.
func (config *Dendrite) setDefaults() {
if config.Matrix.KeyValidityPeriod == 0 {
config.Matrix.KeyValidityPeriod = 24 * time.Hour
@ -376,7 +380,7 @@ func (config *Dendrite) setDefaults() {
}
// Error returns a string detailing how many errors were contained within an
// Error type
// Error type.
func (e Error) Error() string {
if len(e.Problems) == 1 {
return e.Problems[0]
@ -387,7 +391,7 @@ func (e Error) Error() string {
}
// check returns an error type containing all errors found within the config
// file
// file.
func (config *Dendrite) check(monolithic bool) error {
var problems []string
@ -472,7 +476,7 @@ func (config *Dendrite) check(monolithic bool) error {
return nil
}
// absPath returns the absolute path for a given relative or absolute path
// absPath returns the absolute path for a given relative or absolute path.
func absPath(dir string, path Path) string {
if filepath.IsAbs(string(path)) {
// filepath.Join cleans the path so we should clean the absolute paths as well for consistency.

View file

@ -42,11 +42,32 @@ func BuildEvent(
builder *gomatrixserverlib.EventBuilder, cfg config.Dendrite,
queryAPI api.RoomserverQueryAPI, queryRes *api.QueryLatestEventsAndStateResponse,
) (*gomatrixserverlib.Event, error) {
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
err := AddPrevEventsToEvent(ctx, builder, queryAPI, queryRes)
if err != nil {
return nil, err
}
eventID := fmt.Sprintf("$%s:%s", util.RandomString(16), cfg.Matrix.ServerName)
now := time.Now()
event, err := builder.Build(eventID, now, cfg.Matrix.ServerName, cfg.Matrix.KeyID, cfg.Matrix.PrivateKey)
if err != nil {
return nil, err
}
return &event, nil
}
// AddPrevEventsToEvent fills out the prev_events and auth_events fields in builder
func AddPrevEventsToEvent(
ctx context.Context,
builder *gomatrixserverlib.EventBuilder,
queryAPI api.RoomserverQueryAPI, queryRes *api.QueryLatestEventsAndStateResponse,
) error {
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
if err != nil {
return err
}
// Ask the roomserver for information about this room
queryReq := api.QueryLatestEventsAndStateRequest{
RoomID: builder.RoomID,
@ -56,11 +77,11 @@ func BuildEvent(
queryRes = &api.QueryLatestEventsAndStateResponse{}
}
if err = queryAPI.QueryLatestEventsAndState(ctx, &queryReq, queryRes); err != nil {
return nil, err
return err
}
if !queryRes.RoomExists {
return nil, ErrRoomNoExists
return ErrRoomNoExists
}
builder.Depth = queryRes.Depth
@ -71,22 +92,15 @@ func BuildEvent(
for i := range queryRes.StateEvents {
err = authEvents.AddEvent(&queryRes.StateEvents[i])
if err != nil {
return nil, err
return err
}
}
refs, err := eventsNeeded.AuthEventReferences(&authEvents)
if err != nil {
return nil, err
return err
}
builder.AuthEvents = refs
eventID := fmt.Sprintf("$%s:%s", util.RandomString(16), cfg.Matrix.ServerName)
now := time.Now()
event, err := builder.Build(eventID, now, cfg.Matrix.ServerName, cfg.Matrix.KeyID, cfg.Matrix.PrivateKey)
if err != nil {
return nil, err
}
return &event, nil
return nil
}

View file

@ -44,6 +44,11 @@ func NewDatabase(dataSourceName string) (*Database, error) {
return d, nil
}
// FetcherName implements KeyFetcher
func (d Database) FetcherName() string {
return "KeyDatabase"
}
// FetchKeys implements gomatrixserverlib.KeyDatabase
func (d *Database) FetchKeys(
ctx context.Context,

View file

@ -0,0 +1,184 @@
// Copyright 2017 New Vector Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"context"
"encoding/json"
"net/http"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// MakeJoin implements the /make_join API
func MakeJoin(
ctx context.Context,
httpReq *http.Request,
request *gomatrixserverlib.FederationRequest,
cfg config.Dendrite,
query api.RoomserverQueryAPI,
roomID, userID string,
) util.JSONResponse {
_, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.BadJSON("Invalid UserID"),
}
}
if domain != request.Origin() {
return util.JSONResponse{
Code: 403,
JSON: jsonerror.Forbidden("The join must be sent by the server of the user"),
}
}
// Try building an event for the server
builder := gomatrixserverlib.EventBuilder{
Sender: userID,
RoomID: roomID,
Type: "m.room.member",
StateKey: &userID,
}
err = builder.SetContent(map[string]interface{}{"membership": "join"})
if err != nil {
return httputil.LogThenError(httpReq, err)
}
var queryRes api.QueryLatestEventsAndStateResponse
event, err := common.BuildEvent(ctx, &builder, cfg, query, &queryRes)
if err == common.ErrRoomNoExists {
return util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound("Room does not exist"),
}
} else if err != nil {
return httputil.LogThenError(httpReq, err)
}
// Check that the join is allowed or not
stateEvents := make([]*gomatrixserverlib.Event, len(queryRes.StateEvents))
for i := range queryRes.StateEvents {
stateEvents[i] = &queryRes.StateEvents[i]
}
provider := gomatrixserverlib.NewAuthEvents(stateEvents)
if err = gomatrixserverlib.Allowed(*event, &provider); err != nil {
return util.JSONResponse{
Code: 403,
JSON: jsonerror.Forbidden(err.Error()),
}
}
return util.JSONResponse{
Code: 200,
JSON: map[string]interface{}{"event": builder},
}
}
// SendJoin implements the /send_join API
func SendJoin(
ctx context.Context,
httpReq *http.Request,
request *gomatrixserverlib.FederationRequest,
cfg config.Dendrite,
query api.RoomserverQueryAPI,
producer *producers.RoomserverProducer,
keys gomatrixserverlib.KeyRing,
roomID, eventID string,
) util.JSONResponse {
var event gomatrixserverlib.Event
if err := json.Unmarshal(request.Content(), &event); err != nil {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON. " + err.Error()),
}
}
// Check that the room ID is correct.
if event.RoomID() != roomID {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.BadJSON("The room ID in the request path must match the room ID in the join event JSON"),
}
}
// Check that the event ID is correct.
if event.EventID() != eventID {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.BadJSON("The event ID in the request path must match the event ID in the join event JSON"),
}
}
// Check that the event is from the server sending the request.
if event.Origin() != request.Origin() {
return util.JSONResponse{
Code: 403,
JSON: jsonerror.Forbidden("The join must be sent by the server it originated on"),
}
}
// Check that the event is signed by the server sending the request.
verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{
ServerName: event.Origin(),
Message: event.Redact().JSON(),
AtTS: event.OriginServerTS(),
}}
verifyResults, err := keys.VerifyJSONs(ctx, verifyRequests)
if err != nil {
return httputil.LogThenError(httpReq, err)
}
if verifyResults[0].Error != nil {
return util.JSONResponse{
Code: 403,
JSON: jsonerror.Forbidden("The join must be signed by the server it originated on"),
}
}
// Fetch the state and auth chain. We do this before we send the events
// on, in case this fails.
var stateAndAuthChainRepsonse api.QueryStateAndAuthChainResponse
err = query.QueryStateAndAuthChain(ctx, &api.QueryStateAndAuthChainRequest{
PrevEventIDs: event.PrevEventIDs(),
AuthEventIDs: event.AuthEventIDs(),
RoomID: roomID,
}, &stateAndAuthChainRepsonse)
if err != nil {
return httputil.LogThenError(httpReq, err)
}
// Send the events to the room server.
// We are responsible for notifying other servers that the user has joined
// the room, so set SendAsServer to cfg.Matrix.ServerName
err = producer.SendEvents(ctx, []gomatrixserverlib.Event{event}, cfg.Matrix.ServerName)
if err != nil {
return httputil.LogThenError(httpReq, err)
}
return util.JSONResponse{
Code: 200,
JSON: map[string]interface{}{
"state": stateAndAuthChainRepsonse.StateEvents,
"auth_chain": stateAndAuthChainRepsonse.AuthChainEvents,
},
}
}

View file

@ -124,6 +124,30 @@ func Setup(
},
)).Methods("GET")
v1fedmux.Handle("/make_join/{roomID}/{userID}", common.MakeFedAPI(
"federation_make_join", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars := mux.Vars(httpReq)
roomID := vars["roomID"]
userID := vars["userID"]
return MakeJoin(
httpReq.Context(), httpReq, request, cfg, query, roomID, userID,
)
},
)).Methods("GET")
v1fedmux.Handle("/send_join/{roomID}/{userID}", common.MakeFedAPI(
"federation_send_join", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars := mux.Vars(httpReq)
roomID := vars["roomID"]
userID := vars["userID"]
return SendJoin(
httpReq.Context(), httpReq, request, cfg, query, producer, keys, roomID, userID,
)
},
)).Methods("PUT")
v1fedmux.Handle("/version", common.MakeExternalAPI(
"federation_version",
func(httpReq *http.Request) util.JSONResponse {

View file

@ -59,10 +59,10 @@ func Setup(
}
r0mux.Handle("/download/{serverName}/{mediaId}",
makeDownloadAPI("download", cfg, db, client, activeRemoteRequests, activeThumbnailGeneration),
).Methods("GET")
).Methods("GET", "OPTIONS")
r0mux.Handle("/thumbnail/{serverName}/{mediaId}",
makeDownloadAPI("thumbnail", cfg, db, client, activeRemoteRequests, activeThumbnailGeneration),
).Methods("GET")
).Methods("GET", "OPTIONS")
}
func makeDownloadAPI(

View file

@ -36,7 +36,7 @@ func Setup(apiMux *mux.Router, deviceDB *devices.Database, publicRoomsDB *storag
vars := mux.Vars(req)
return directory.GetVisibility(req, publicRoomsDB, vars["roomID"])
}),
).Methods("GET")
).Methods("GET", "OPTIONS")
r0mux.Handle("/directory/list/room/{roomID}",
common.MakeAuthAPI("directory_list", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)

View file

@ -155,6 +155,33 @@ type QueryServerAllowedToSeeEventResponse struct {
AllowedToSeeEvent bool `json:"can_see_event"`
}
// QueryStateAndAuthChainRequest is a request to QueryStateAndAuthChain
type QueryStateAndAuthChainRequest struct {
// The room ID to query the state in.
RoomID string `json:"room_id"`
// The list of prev events for the event. Used to calculate the state at
// the event
PrevEventIDs []string `json:"prev_event_ids"`
// The list of auth events for the event. Used to calculate the auth chain
AuthEventIDs []string `json:"auth_event_ids"`
}
// QueryStateAndAuthChainResponse is a response to QueryStateAndAuthChain
type QueryStateAndAuthChainResponse struct {
// Copy of the request for debugging.
QueryStateAndAuthChainRequest
// Does the room exist on this roomserver?
// If the room doesn't exist this will be false and StateEvents will be empty.
RoomExists bool `json:"room_exists"`
// Do all the previous events exist on this roomserver?
// If some of previous events do not exist this will be false and StateEvents will be empty.
PrevEventsExist bool `json:"prev_events_exist"`
// The state and auth chain events that were requested.
// The lists will be in an arbitrary order.
StateEvents []gomatrixserverlib.Event `json:"state_events"`
AuthChainEvents []gomatrixserverlib.Event `json:"auth_chain_events"`
}
// RoomserverQueryAPI is used to query information from the room server.
type RoomserverQueryAPI interface {
// Query the latest events and state for a room from the room server.
@ -198,6 +225,15 @@ type RoomserverQueryAPI interface {
request *QueryServerAllowedToSeeEventRequest,
response *QueryServerAllowedToSeeEventResponse,
) error
// Query to get state and auth chain for a (potentially hypothetical) event.
// Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate
// the state and auth chain to return.
QueryStateAndAuthChain(
ctx context.Context,
request *QueryStateAndAuthChainRequest,
response *QueryStateAndAuthChainResponse,
) error
}
// RoomserverQueryLatestEventsAndStatePath is the HTTP path for the QueryLatestEventsAndState API.
@ -218,6 +254,9 @@ const RoomserverQueryInvitesForUserPath = "/api/roomserver/queryInvitesForUser"
// RoomserverQueryServerAllowedToSeeEventPath is the HTTP path for the QueryServerAllowedToSeeEvent API
const RoomserverQueryServerAllowedToSeeEventPath = "/api/roomserver/queryServerAllowedToSeeEvent"
// RoomserverQueryStateAndAuthChainPath is the HTTP path for the QueryStateAndAuthChain API
const RoomserverQueryStateAndAuthChainPath = "/api/roomserver/queryStateAndAuthChain"
// NewRoomserverQueryAPIHTTP creates a RoomserverQueryAPI implemented by talking to a HTTP POST API.
// If httpClient is nil then it uses the http.DefaultClient
func NewRoomserverQueryAPIHTTP(roomserverURL string, httpClient *http.Client) RoomserverQueryAPI {
@ -310,6 +349,19 @@ func (h *httpRoomserverQueryAPI) QueryServerAllowedToSeeEvent(
return postJSON(ctx, span, h.httpClient, apiURL, request, response)
}
// QueryStateAndAuthChain implements RoomserverQueryAPI
func (h *httpRoomserverQueryAPI) QueryStateAndAuthChain(
ctx context.Context,
request *QueryStateAndAuthChainRequest,
response *QueryStateAndAuthChainResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateAndAuthChain")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryStateAndAuthChainPath
return postJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func postJSON(
ctx context.Context, span opentracing.Span, httpClient *http.Client,
apiURL string, request, response interface{},

View file

@ -27,9 +27,19 @@ import (
"github.com/matrix-org/util"
)
// RoomserverQueryAPIEventDB has a convenience API to fetch events directly by
// EventIDs.
type RoomserverQueryAPIEventDB interface {
// Look up the Events for a list of event IDs. Does not error if event was
// not found.
// Returns an error if the retrieval went wrong.
EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error)
}
// RoomserverQueryAPIDatabase has the storage APIs needed to implement the query API.
type RoomserverQueryAPIDatabase interface {
state.RoomStateDatabase
RoomserverQueryAPIEventDB
// Look up the numeric ID for the room.
// Returns 0 if the room doesn't exists.
// Returns an error if there was a problem talking to the database.
@ -418,6 +428,98 @@ func (r *RoomserverQueryAPI) QueryServerAllowedToSeeEvent(
return nil
}
// QueryStateAndAuthChain implements api.RoomserverQueryAPI
func (r *RoomserverQueryAPI) QueryStateAndAuthChain(
ctx context.Context,
request *api.QueryStateAndAuthChainRequest,
response *api.QueryStateAndAuthChainResponse,
) error {
response.QueryStateAndAuthChainRequest = *request
roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
if err != nil {
return err
}
if roomNID == 0 {
return nil
}
response.RoomExists = true
prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs)
if err != nil {
switch err.(type) {
case types.MissingEventError:
return nil
default:
return err
}
}
response.PrevEventsExist = true
// Look up the currrent state for the requested tuples.
stateEntries, err := state.LoadCombinedStateAfterEvents(
ctx, r.DB, prevStates,
)
if err != nil {
return err
}
stateEvents, err := r.loadStateEvents(ctx, stateEntries)
if err != nil {
return err
}
response.StateEvents = stateEvents
response.AuthChainEvents, err = getAuthChain(ctx, r.DB, request.AuthEventIDs)
return err
}
// getAuthChain fetches the auth chain for the given auth events.
// An auth chain is the list of all events that are referenced in the
// auth_events section, and all their auth_events, recursively.
// The returned set of events contain the given events.
// Will *not* error if we don't have all auth events.
func getAuthChain(
ctx context.Context, dB RoomserverQueryAPIEventDB, authEventIDs []string,
) ([]gomatrixserverlib.Event, error) {
var authEvents []gomatrixserverlib.Event
// List of event ids to fetch. These will be added to the result and
// their auth events will be fetched (if they haven't been previously)
eventsToFetch := authEventIDs
// Set of events we've already fetched.
fetchedEventMap := make(map[string]bool)
// Check if there's anything left to do
for len(eventsToFetch) > 0 {
// Convert eventIDs to events. First need to fetch NIDs
events, err := dB.EventsFromIDs(ctx, eventsToFetch)
if err != nil {
return nil, err
}
// Work out a) which events we should add to the returned list of
// events and b) which of the auth events we haven't seen yet and
// add them to the list of events to fetch.
eventsToFetch = eventsToFetch[:0]
for _, event := range events {
fetchedEventMap[event.EventID()] = true
authEvents = append(authEvents, event.Event)
// Now we need to fetch any auth events that we haven't
// previously seen.
for _, authEventID := range event.AuthEventIDs() {
if !fetchedEventMap[authEventID] {
fetchedEventMap[authEventID] = true
eventsToFetch = append(eventsToFetch, authEventID)
}
}
}
}
return authEvents, nil
}
// SetupHTTP adds the RoomserverQueryAPI handlers to the http.ServeMux.
// nolint: gocyclo
func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
@ -505,4 +607,18 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
return util.JSONResponse{Code: 200, JSON: &response}
}),
)
servMux.Handle(
api.RoomserverQueryStateAndAuthChainPath,
common.MakeInternalAPI("queryStateAndAuthChain", func(req *http.Request) util.JSONResponse {
var request api.QueryStateAndAuthChainRequest
var response api.QueryStateAndAuthChainResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryStateAndAuthChain(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: 200, JSON: &response}
}),
)
}

View file

@ -0,0 +1,174 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package query
import (
"context"
"encoding/json"
"testing"
"sort"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
// used to implement RoomserverQueryAPIEventDB to test getAuthChain
type getEventDB struct {
eventMap map[string]gomatrixserverlib.Event
}
func createEventDB() *getEventDB {
return &getEventDB{
eventMap: make(map[string]gomatrixserverlib.Event),
}
}
// Adds a fake event to the storage with given auth events.
func (db *getEventDB) addFakeEvent(eventID string, authIDs []string) error {
authEvents := []gomatrixserverlib.EventReference{}
for _, authID := range authIDs {
authEvents = append(authEvents, gomatrixserverlib.EventReference{
EventID: authID,
})
}
builder := map[string]interface{}{
"event_id": eventID,
"auth_events": authEvents,
}
eventJSON, err := json.Marshal(&builder)
if err != nil {
return err
}
event, err := gomatrixserverlib.NewEventFromTrustedJSON(eventJSON, false)
if err != nil {
return err
}
db.eventMap[eventID] = event
return nil
}
// Adds multiple events at once, each entry in the map is an eventID and set of
// auth events that are converted to an event and added.
func (db *getEventDB) addFakeEvents(graph map[string][]string) error {
for eventID, authIDs := range graph {
err := db.addFakeEvent(eventID, authIDs)
if err != nil {
return err
}
}
return nil
}
// EventsFromIDs implements RoomserverQueryAPIEventDB
func (db *getEventDB) EventsFromIDs(ctx context.Context, eventIDs []string) (res []types.Event, err error) {
for _, evID := range eventIDs {
res = append(res, types.Event{
EventNID: 0,
Event: db.eventMap[evID],
})
}
return
}
// Returns if the slices are equal after sorting them.
func compareUnsortedStringSlices(a []string, b []string) bool {
if len(a) != len(b) {
return false
}
sort.Strings(a)
sort.Strings(b)
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func TestGetAuthChainSingle(t *testing.T) {
db := createEventDB()
err := db.addFakeEvents(map[string][]string{
"a": {},
"b": {"a"},
"c": {"a", "b"},
"d": {"b", "c"},
"e": {"a", "d"},
})
if err != nil {
t.Fatalf("Failed to add events to db: %v", err)
}
result, err := getAuthChain(context.TODO(), db, []string{"e"})
if err != nil {
t.Fatalf("getAuthChain failed: %v", err)
}
var returnedIDs []string
for _, event := range result {
returnedIDs = append(returnedIDs, event.EventID())
}
expectedIDs := []string{"a", "b", "c", "d", "e"}
if !compareUnsortedStringSlices(expectedIDs, returnedIDs) {
t.Fatalf("returnedIDs got '%v', expected '%v'", returnedIDs, expectedIDs)
}
}
func TestGetAuthChainMultiple(t *testing.T) {
db := createEventDB()
err := db.addFakeEvents(map[string][]string{
"a": {},
"b": {"a"},
"c": {"a", "b"},
"d": {"b", "c"},
"e": {"a", "d"},
"f": {"a", "b", "c"},
})
if err != nil {
t.Fatalf("Failed to add events to db: %v", err)
}
result, err := getAuthChain(context.TODO(), db, []string{"e", "f"})
if err != nil {
t.Fatalf("getAuthChain failed: %v", err)
}
var returnedIDs []string
for _, event := range result {
returnedIDs = append(returnedIDs, event.EventID())
}
expectedIDs := []string{"a", "b", "c", "d", "e", "f"}
if !compareUnsortedStringSlices(expectedIDs, returnedIDs) {
t.Fatalf("returnedIDs got '%v', expected '%v'", returnedIDs, expectedIDs)
}
}

View file

@ -651,6 +651,21 @@ func (d *Database) GetMembershipEventNIDsForRoom(
return d.statements.selectMembershipsFromRoom(ctx, roomNID)
}
// EventsFromIDs implements query.RoomserverQueryAPIEventDB
func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
nidMap, err := d.EventNIDs(ctx, eventIDs)
if err != nil {
return nil, err
}
var nids []types.EventNID
for _, nid := range nidMap {
nids = append(nids, nid)
}
return d.Events(ctx, nids)
}
type transaction struct {
ctx context.Context
txn *sql.Tx

View file

@ -34,20 +34,20 @@ func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB *storage.SyncServer
r0mux.Handle("/sync", common.MakeAuthAPI("sync", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return srp.OnIncomingSyncRequest(req, device)
})).Methods("GET")
})).Methods("GET", "OPTIONS")
r0mux.Handle("/rooms/{roomID}/state", common.MakeAuthAPI("room_state", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return OnIncomingStateRequest(req, syncDB, vars["roomID"])
})).Methods("GET")
})).Methods("GET", "OPTIONS")
r0mux.Handle("/rooms/{roomID}/state/{type}", common.MakeAuthAPI("room_state", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return OnIncomingStateTypeRequest(req, syncDB, vars["roomID"], vars["type"], "")
})).Methods("GET")
})).Methods("GET", "OPTIONS")
r0mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", common.MakeAuthAPI("room_state", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return OnIncomingStateTypeRequest(req, syncDB, vars["roomID"], vars["type"], vars["stateKey"])
})).Methods("GET")
})).Methods("GET", "OPTIONS")
}

20
vendor/manifest vendored
View file

@ -135,7 +135,7 @@
{
"importpath": "github.com/matrix-org/gomatrixserverlib",
"repository": "https://github.com/matrix-org/gomatrixserverlib",
"revision": "076933f95312aae3a9476e78d6b4118e1b45d542",
"revision": "8540d3dfc13c797cd3200640bc06e0286ab355aa",
"branch": "master"
},
{
@ -274,6 +274,24 @@
"branch": "master",
"path": "/require"
},
{
"importpath": "github.com/tidwall/gjson",
"repository": "https://github.com/tidwall/gjson",
"revision": "67e2a63ac70d273b6bc7589f12f07180bc9fc189",
"branch": "master"
},
{
"importpath": "github.com/tidwall/match",
"repository": "https://github.com/tidwall/match",
"revision": "1731857f09b1f38450e2c12409748407822dc6be",
"branch": "master"
},
{
"importpath": "github.com/tidwall/sjson",
"repository": "https://github.com/tidwall/sjson",
"revision": "6a22caf2fd45d5e2119bfc3717e984f15a7eb7ee",
"branch": "master"
},
{
"importpath": "github.com/tj/go-debug",
"repository": "https://github.com/tj/go-debug",

View file

@ -175,7 +175,29 @@ func (fc *Client) LookupUserInfo(
return
}
// LookupServerKeys lookups up the keys for a matrix server from a matrix server.
// GetServerKeys asks a matrix server for its signing keys and TLS cert
func (fc *Client) GetServerKeys(
ctx context.Context, matrixServer ServerName,
) (ServerKeys, error) {
url := url.URL{
Scheme: "matrix",
Host: string(matrixServer),
Path: "/_matrix/key/v2/server",
}
var body ServerKeys
req, err := http.NewRequest("GET", url.String(), nil)
if err != nil {
return body, err
}
err = fc.DoRequestAndParseResponse(
ctx, req, &body,
)
return body, err
}
// LookupServerKeys looks up the keys for a matrix server from a matrix server.
// The first argument is the name of the matrix server to download the keys from.
// The second argument is a map from (server name, key ID) pairs to timestamps.
// The (server name, key ID) pair identifies the key to download.

View file

@ -16,11 +16,13 @@
package gomatrixserverlib
import (
"bytes"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/tidwall/sjson"
"golang.org/x/crypto/ed25519"
)
@ -183,38 +185,53 @@ func (eb *EventBuilder) Build(eventID string, now time.Time, origin ServerName,
// It also checks the content hashes to ensure the event has not been tampered with.
// This should be used when receiving new events from remote servers.
func NewEventFromUntrustedJSON(eventJSON []byte) (result Event, err error) {
var event map[string]rawJSON
if err = json.Unmarshal(eventJSON, &event); err != nil {
return
}
// Synapse removes these keys from events in case a server accidentally added them.
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/crypto/event_signing.py#L57-L62
delete(event, "outlier")
delete(event, "destinations")
delete(event, "age_ts")
if eventJSON, err = json.Marshal(event); err != nil {
return
}
if err = checkEventContentHash(eventJSON); err != nil {
result.redacted = true
// If the content hash doesn't match then we have to discard all non-essential fields
// because they've been tampered with.
if eventJSON, err = redactEvent(eventJSON); err != nil {
return
}
}
if eventJSON, err = CanonicalJSON(eventJSON); err != nil {
return
}
result.eventJSON = eventJSON
// We parse the JSON early on so that we don't have to check if the JSON
// is valid
if err = json.Unmarshal(eventJSON, &result.fields); err != nil {
return
}
// Synapse removes these keys from events in case a server accidentally added them.
// https://github.com/matrix-org/synapse/blob/v0.18.5/synapse/crypto/event_signing.py#L57-L62
for _, key := range []string{"outlier", "destinations", "age_ts"} {
if eventJSON, err = sjson.DeleteBytes(eventJSON, key); err != nil {
return
}
}
// We know the JSON must be valid here.
eventJSON = CanonicalJSONAssumeValid(eventJSON)
if err = checkEventContentHash(eventJSON); err != nil {
result.redacted = true
// If the content hash doesn't match then we have to discard all non-essential fields
// because they've been tampered with.
var redactedJSON []byte
if redactedJSON, err = redactEvent(eventJSON); err != nil {
return
}
redactedJSON = CanonicalJSONAssumeValid(redactedJSON)
// We need to ensure that `result` is the redacted event.
// If redactedJSON is the same as eventJSON then `result` is already
// correct. If not then we need to reparse.
//
// Yes, this means that for some events we parse twice (which is slow),
// but means that parsing unredacted events is fast.
if !bytes.Equal(redactedJSON, eventJSON) {
result = Event{redacted: true}
if err = json.Unmarshal(redactedJSON, &result.fields); err != nil {
return
}
}
eventJSON = redactedJSON
}
result.eventJSON = eventJSON
if err = result.CheckFields(); err != nil {
return
}

View file

@ -0,0 +1,52 @@
/* Copyright 2017 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package gomatrixserverlib
import (
"encoding/json"
"testing"
)
func benchmarkParse(b *testing.B, eventJSON string) {
var event Event
// run the Unparse function b.N times
for n := 0; n < b.N; n++ {
if err := json.Unmarshal([]byte(eventJSON), &event); err != nil {
b.Error("Failed to parse event")
}
}
}
// Benchmark a more complicated event, in this case a power levels event.
func BenchmarkParseLargerEvent(b *testing.B) {
benchmarkParse(b, `{"auth_events":[["$Stdin0028C5qBjz5:localhost",{"sha256":"PvTyW+Mfb0aCajkIlBk1XlQE+1uVco3to8C2+/1J7iQ"}],["$klXtjBwwDQIGglax:localhost",{"sha256":"hLoiSkcGLZJr5wkIDA8+bujNJPsYX1SOCCXIErHEcgM"}]],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100},"events_default":0,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@test:localhost":100},"users_default":0},"depth":3,"event_id":"$7gPR7SLdkfDsMvJL:localhost","hashes":{"sha256":"/kQnrzO5vhbnwyGvKso4CVMRyyryiyanq6t27mt5kSw"},"origin":"localhost","origin_server_ts":1510854446548,"prev_events":[["$klXtjBwwDQIGglax:localhost",{"sha256":"hLoiSkcGLZJr5wkIDA8+bujNJPsYX1SOCCXIErHEcgM"}]],"prev_state":[],"room_id":"!pUjJbIC8V32G0FLt:localhost","sender":"@test:localhost","signatures":{"localhost":{"ed25519:u9kP":"NOxjrcci7AIRhcTVmJ6nrsslLsaOJzB0iusDZ6cOFrv2OXkDY7mrBM3cQQS3DhGWltEtu3OC0nsvkfeYtwr9DQ"}},"state_key":"","type":"m.room.power_levels"}`)
}
// Lets now test parsing a smaller name event, first one that is valid, then wrong hash, and then the redacted one
func BenchmarkParseSmallerEvent(b *testing.B) {
benchmarkParse(b, `{"auth_events":[["$oXL79cT7fFxR7dPH:localhost",{"sha256":"abjkiDSg1RkuZrbj2jZoGMlQaaj1Ue3Jhi7I7NlKfXY"}],["$IVUsaSkm1LBAZYYh:localhost",{"sha256":"X7RUj46hM/8sUHNBIFkStbOauPvbDzjSdH4NibYWnko"}],["$VS2QT0EeArZYi8wf:localhost",{"sha256":"k9eM6utkCH8vhLW9/oRsH74jOBS/6RVK42iGDFbylno"}]],"content":{"name":"test3"},"depth":7,"event_id":"$yvN1b43rlmcOs5fY:localhost","hashes":{"sha256":"Oh1mwI1jEqZ3tgJ+V1Dmu5nOEGpCE4RFUqyJv2gQXKs"},"origin":"localhost","origin_server_ts":1510854416361,"prev_events":[["$FqI6TVvWpcbcnJ97:localhost",{"sha256":"upCsBqUhNUgT2/+zkzg8TbqdQpWWKQnZpGJc6KcbUC4"}]],"prev_state":[],"room_id":"!19Mp0U9hjajeIiw1:localhost","sender":"@test:localhost","signatures":{"localhost":{"ed25519:u9kP":"5IzSuRXkxvbTp0vZhhXYZeOe+619iG3AybJXr7zfNn/4vHz4TH7qSJVQXSaHHvcTcDodAKHnTG1WDulgO5okAQ"}},"state_key":"","type":"m.room.name"}`)
}
func BenchmarkParseSmallerEventFailedHash(b *testing.B) {
benchmarkParse(b, `{"auth_events":[["$oXL79cT7fFxR7dPH:localhost",{"sha256":"abjkiDSg1RkuZrbj2jZoGMlQaaj1Ue3Jhi7I7NlKfXY"}],["$IVUsaSkm1LBAZYYh:localhost",{"sha256":"X7RUj46hM/8sUHNBIFkStbOauPvbDzjSdH4NibYWnko"}],["$VS2QT0EeArZYi8wf:localhost",{"sha256":"k9eM6utkCH8vhLW9/oRsH74jOBS/6RVK42iGDFbylno"}]],"content":{"name":"test4"},"depth":7,"event_id":"$yvN1b43rlmcOs5fY:localhost","hashes":{"sha256":"Oh1mwI1jEqZ3tgJ+V1Dmu5nOEGpCE4RFUqyJv2gQXKs"},"origin":"localhost","origin_server_ts":1510854416361,"prev_events":[["$FqI6TVvWpcbcnJ97:localhost",{"sha256":"upCsBqUhNUgT2/+zkzg8TbqdQpWWKQnZpGJc6KcbUC4"}]],"prev_state":[],"room_id":"!19Mp0U9hjajeIiw1:localhost","sender":"@test:localhost","signatures":{"localhost":{"ed25519:u9kP":"5IzSuRXkxvbTp0vZhhXYZeOe+619iG3AybJXr7zfNn/4vHz4TH7qSJVQXSaHHvcTcDodAKHnTG1WDulgO5okAQ"}},"state_key":"","type":"m.room.name"}`)
}
func BenchmarkParseSmallerEventRedacted(b *testing.B) {
benchmarkParse(b, `{"event_id":"$yvN1b43rlmcOs5fY:localhost","sender":"@test:localhost","room_id":"!19Mp0U9hjajeIiw1:localhost","hashes":{"sha256":"Oh1mwI1jEqZ3tgJ+V1Dmu5nOEGpCE4RFUqyJv2gQXKs"},"signatures":{"localhost":{"ed25519:u9kP":"5IzSuRXkxvbTp0vZhhXYZeOe+619iG3AybJXr7zfNn/4vHz4TH7qSJVQXSaHHvcTcDodAKHnTG1WDulgO5okAQ"}},"content":{},"type":"m.room.name","state_key":"","depth":7,"prev_events":[["$FqI6TVvWpcbcnJ97:localhost",{"sha256":"upCsBqUhNUgT2/+zkzg8TbqdQpWWKQnZpGJc6KcbUC4"}]],"prev_state":[],"auth_events":[["$oXL79cT7fFxR7dPH:localhost",{"sha256":"abjkiDSg1RkuZrbj2jZoGMlQaaj1Ue3Jhi7I7NlKfXY"}],["$IVUsaSkm1LBAZYYh:localhost",{"sha256":"X7RUj46hM/8sUHNBIFkStbOauPvbDzjSdH4NibYWnko"}],["$VS2QT0EeArZYi8wf:localhost",{"sha256":"k9eM6utkCH8vhLW9/oRsH74jOBS/6RVK42iGDFbylno"}]],"origin":"localhost","origin_server_ts":1510854416361}`)
}

View file

@ -22,6 +22,8 @@ import (
"encoding/json"
"fmt"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"golang.org/x/crypto/ed25519"
)
@ -68,40 +70,28 @@ func addContentHashesToEvent(eventJSON []byte) ([]byte, error) {
}
// checkEventContentHash checks if the unredacted content of the event matches the SHA-256 hash under the "hashes" key.
// Assumes that eventJSON has been canonicalised already.
func checkEventContentHash(eventJSON []byte) error {
var event map[string]rawJSON
var err error
if err := json.Unmarshal(eventJSON, &event); err != nil {
result := gjson.GetBytes(eventJSON, "hashes.sha256")
var hash Base64String
if err = hash.Decode(result.Str); err != nil {
return err
}
hashesJSON := event["hashes"]
hashableEventJSON := eventJSON
delete(event, "signatures")
delete(event, "unsigned")
delete(event, "hashes")
var hashes struct {
Sha256 Base64String `json:"sha256"`
}
if err := json.Unmarshal(hashesJSON, &hashes); err != nil {
for _, key := range []string{"signatures", "unsigned", "hashes"} {
if hashableEventJSON, err = sjson.DeleteBytes(hashableEventJSON, key); err != nil {
return err
}
hashableEventJSON, err := json.Marshal(event)
if err != nil {
return err
}
hashableEventJSON, err = CanonicalJSON(hashableEventJSON)
if err != nil {
return err
}
sha256Hash := sha256.Sum256(hashableEventJSON)
if !bytes.Equal(sha256Hash[:], []byte(hashes.Sha256)) {
return fmt.Errorf("Invalid Sha256 content hash: %v != %v", sha256Hash[:], []byte(hashes.Sha256))
if !bytes.Equal(sha256Hash[:], []byte(hash)) {
return fmt.Errorf("Invalid Sha256 content hash: %v != %v", sha256Hash[:], []byte(hash))
}
return nil

View file

@ -4,6 +4,8 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/matrix-org/util"
)
// A RespSend is the content of a response to PUT /_matrix/federation/v1/send/{txnID}/
@ -109,6 +111,7 @@ func (r RespState) Events() ([]Event, error) {
// Check that a response to /state is valid.
func (r RespState) Check(ctx context.Context, keyRing JSONVerifier) error {
logger := util.GetLogger(ctx)
var allEvents []Event
for _, event := range r.AuthEvents {
if event.StateKey() == nil {
@ -134,8 +137,9 @@ func (r RespState) Check(ctx context.Context, keyRing JSONVerifier) error {
}
// Check if the events pass signature checks.
logger.Infof("Checking event signatures for %d events of room state", len(allEvents))
if err := VerifyEventSignatures(ctx, allEvents, keyRing); err != nil {
return nil
return err
}
eventsByID := map[string]*Event{}

View file

@ -2,6 +2,25 @@
set -eu
# make the GIT_DIR and GIT_INDEX_FILE absolute, before we change dir
export GIT_DIR=$(readlink -f `git rev-parse --git-dir`)
if [ -n "${GIT_INDEX_FILE:+x}" ]; then
export GIT_INDEX_FILE=$(readlink -f "$GIT_INDEX_FILE")
fi
wd=`pwd`
# create a temp dir. The `trap` incantation will ensure that it is removed
# again when this script completes.
tmpdir=`mktemp -d`
trap 'rm -rf "$tmpdir"' EXIT
cd "$tmpdir"
# get a clean copy of the index (ie, what has been `git add`ed), so that we can
# run the checks against what we are about to commit, rather than what is in
# the working copy.
git checkout-index -a
echo "Installing lint search engine..."
go get github.com/alecthomas/gometalinter/
gometalinter --config=linter.json --install --update

View file

@ -16,66 +16,73 @@
package gomatrixserverlib
import (
"bytes"
"encoding/binary"
"encoding/json"
"sort"
"unicode/utf8"
"github.com/pkg/errors"
"github.com/tidwall/gjson"
)
// CanonicalJSON re-encodes the JSON in a canonical encoding. The encoding is
// the shortest possible encoding using integer values with sorted object keys.
// https://matrix.org/docs/spec/server_server/unstable.html#canonical-json
func CanonicalJSON(input []byte) ([]byte, error) {
sorted, err := SortJSON(input, make([]byte, 0, len(input)))
if err != nil {
return nil, err
if !gjson.Valid(string(input)) {
return nil, errors.Errorf("invalid json")
}
return CompactJSON(sorted, make([]byte, 0, len(sorted))), nil
return CanonicalJSONAssumeValid(input), nil
}
// CanonicalJSONAssumeValid is the same as CanonicalJSON, but assumes the
// input is valid JSON
func CanonicalJSONAssumeValid(input []byte) []byte {
input = CompactJSON(input, make([]byte, 0, len(input)))
return SortJSON(input, make([]byte, 0, len(input)))
}
// SortJSON reencodes the JSON with the object keys sorted by lexicographically
// by codepoint. The input must be valid JSON.
func SortJSON(input, output []byte) ([]byte, error) {
// Skip to the first character that isn't whitespace.
var decoded interface{}
func SortJSON(input, output []byte) []byte {
result := gjson.ParseBytes(input)
decoder := json.NewDecoder(bytes.NewReader(input))
decoder.UseNumber()
if err := decoder.Decode(&decoded); err != nil {
return nil, err
}
return sortJSONValue(decoded, output)
rawJSON := rawJSONFromResult(result, input)
return sortJSONValue(result, rawJSON, output)
}
func sortJSONValue(input interface{}, output []byte) ([]byte, error) {
switch value := input.(type) {
case []interface{}:
// If the JSON is an array then we need to sort the keys of its children.
return sortJSONArray(value, output)
case map[string]interface{}:
// If the JSON is an object then we need to sort its keys and the keys of its children.
return sortJSONObject(value, output)
default:
// Otherwise the JSON is a value and can be encoded without any further sorting.
bytes, err := json.Marshal(value)
if err != nil {
return nil, err
// sortJSONValue takes a gjson.Result and sorts it. inputJSON must be the
// raw JSON bytes that gjson.Result points to.
func sortJSONValue(input gjson.Result, inputJSON, output []byte) []byte {
if input.IsArray() {
return sortJSONArray(input, inputJSON, output)
}
return append(output, bytes...), nil
if input.IsObject() {
return sortJSONObject(input, inputJSON, output)
}
// If its neither an object nor an array then there is no sub structure
// to sort, so just append the raw bytes.
return append(output, inputJSON...)
}
func sortJSONArray(input []interface{}, output []byte) ([]byte, error) {
var err error
// sortJSONArray takes a gjson.Result and sorts it, assuming its an array.
// inputJSON must be the raw JSON bytes that gjson.Result points to.
func sortJSONArray(input gjson.Result, inputJSON, output []byte) []byte {
sep := byte('[')
for _, value := range input {
// Iterate over each value in the array and sort it.
input.ForEach(func(_, value gjson.Result) bool {
output = append(output, sep)
sep = ','
if output, err = sortJSONValue(value, output); err != nil {
return nil, err
}
}
rawJSON := rawJSONFromResult(value, inputJSON)
output = sortJSONValue(value, rawJSON, output)
return true // keep iterating
})
if sep == '[' {
// If sep is still '[' then the array was empty and we never wrote the
// initial '[', so we write it now along with the closing ']'.
@ -84,31 +91,49 @@ func sortJSONArray(input []interface{}, output []byte) ([]byte, error) {
// Otherwise we end the array by writing a single ']'
output = append(output, ']')
}
return output, nil
return output
}
func sortJSONObject(input map[string]interface{}, output []byte) ([]byte, error) {
var err error
keys := make([]string, len(input))
var j int
for key := range input {
keys[j] = key
j++
// sortJSONObject takes a gjson.Result and sorts it, assuming its an object.
// inputJSON must be the raw JSON bytes that gjson.Result points to.
func sortJSONObject(input gjson.Result, inputJSON, output []byte) []byte {
type entry struct {
key string // The parsed key string
rawKey []byte // The raw, unparsed key JSON string
value gjson.Result
}
sort.Strings(keys)
var entries []entry
// Iterate over each key/value pair and add it to a slice
// that we can sort
input.ForEach(func(key, value gjson.Result) bool {
entries = append(entries, entry{
key: key.String(),
rawKey: rawJSONFromResult(key, inputJSON),
value: value,
})
return true // keep iterating
})
// Sort the slice based on the *parsed* key
sort.Slice(entries, func(a, b int) bool {
return entries[a].key < entries[b].key
})
sep := byte('{')
for _, key := range keys {
for _, entry := range entries {
output = append(output, sep)
sep = ','
var encoded []byte
if encoded, err = json.Marshal(key); err != nil {
return nil, err
}
output = append(output, encoded...)
// Append the raw unparsed JSON key, *not* the parsed key
output = append(output, entry.rawKey...)
output = append(output, ':')
if output, err = sortJSONValue(input[key], output); err != nil {
return nil, err
}
rawJSON := rawJSONFromResult(entry.value, inputJSON)
output = sortJSONValue(entry.value, rawJSON, output)
}
if sep == '{' {
// If sep is still '{' then the object was empty and we never wrote the
@ -118,7 +143,7 @@ func sortJSONObject(input map[string]interface{}, output []byte) ([]byte, error)
// Otherwise we end the object by writing a single '}'
output = append(output, '}')
}
return output, nil
return output
}
// CompactJSON makes the encoded JSON as small as possible by removing
@ -237,3 +262,19 @@ func readHexDigits(input []byte) uint32 {
hex |= hex >> 8
return hex & 0xFFFF
}
// rawJSONFromResult extracts the raw JSON bytes pointed to by result.
// input must be the json bytes that were used to generate result
func rawJSONFromResult(result gjson.Result, input []byte) (rawJSON []byte) {
// This is lifted from gjson README. Basically, result.Raw is a copy of
// the bytes we want, but its more efficient to take a slice.
// If Index is 0 then for some reason we can't extract it from the original
// JSON bytes.
if result.Index > 0 {
rawJSON = input[result.Index : result.Index+len(result.Raw)]
} else {
rawJSON = []byte(result.Raw)
}
return
}

View file

@ -20,10 +20,8 @@ import (
)
func testSortJSON(t *testing.T, input, want string) {
got, err := SortJSON([]byte(input), nil)
if err != nil {
t.Error(err)
}
got := SortJSON([]byte(input), nil)
// Squash out the whitespace before comparing the JSON in case SortJSON had inserted whitespace.
if string(CompactJSON(got, nil)) != want {
t.Errorf("SortJSON(%q): want %q got %q", input, want, got)
@ -36,6 +34,7 @@ func TestSortJSON(t *testing.T) {
`{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`)
testSortJSON(t, `[true,false,null]`, `[true,false,null]`)
testSortJSON(t, `[9007199254740991]`, `[9007199254740991]`)
testSortJSON(t, "\t\n[9007199254740991]", `[9007199254740991]`)
}
func testCompactJSON(t *testing.T, input, want string) {

View file

@ -6,6 +6,7 @@ import (
"strings"
"time"
"github.com/matrix-org/util"
"golang.org/x/crypto/ed25519"
)
@ -60,6 +61,10 @@ type KeyFetcher interface {
// The result may have more (server name, key ID) pairs than were in the request.
// Returns an error if there was a problem fetching the keys.
FetchKeys(ctx context.Context, requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]PublicKeyLookupResult, error)
// FetcherName returns the name of this fetcher, which can then be used for
// logging errors etc.
FetcherName() string
}
// A KeyDatabase is a store for caching public keys.
@ -113,6 +118,7 @@ type JSONVerifier interface {
// VerifyJSONs implements JSONVerifier.
func (k KeyRing) VerifyJSONs(ctx context.Context, requests []VerifyJSONRequest) ([]VerifyJSONResult, error) { // nolint: gocyclo
logger := util.GetLogger(ctx)
results := make([]VerifyJSONResult, len(requests))
keyIDs := make([][]KeyID, len(requests))
@ -154,7 +160,7 @@ func (k KeyRing) VerifyJSONs(ctx context.Context, requests []VerifyJSONRequest)
}
k.checkUsingKeys(requests, results, keyIDs, keysFromDatabase)
for i := range k.KeyFetchers {
for _, fetcher := range k.KeyFetchers {
// TODO: we should distinguish here between expired keys, and those we don't have.
// If the key has expired, it's no use re-requesting it.
keyRequests := k.publicKeyRequests(requests, results, keyIDs)
@ -163,12 +169,22 @@ func (k KeyRing) VerifyJSONs(ctx context.Context, requests []VerifyJSONRequest)
// This means that we've checked every JSON object we can check.
return results, nil
}
fetcherLogger := logger.WithField("fetcher", fetcher.FetcherName())
// TODO: Coalesce in-flight requests for the same keys.
// Otherwise we risk spamming the servers we query the keys from.
keysFetched, err := k.KeyFetchers[i].FetchKeys(ctx, keyRequests)
fetcherLogger.WithField("num_key_requests", len(keyRequests)).
Info("Requesting keys from fetcher")
keysFetched, err := fetcher.FetchKeys(ctx, keyRequests)
if err != nil {
return nil, err
}
fetcherLogger.WithField("num_keys_fetched", len(keysFetched)).
Info("Got keys from fetcher")
k.checkUsingKeys(requests, results, keyIDs, keysFetched)
// Add the keys to the database so that we won't need to fetch them again.
@ -259,6 +275,11 @@ type PerspectiveKeyFetcher struct {
Client Client
}
// FetcherName implements KeyFetcher
func (p PerspectiveKeyFetcher) FetcherName() string {
return fmt.Sprintf("perspective server %s", p.PerspectiveServerName)
}
// FetchKeys implements KeyFetcher
func (p *PerspectiveKeyFetcher) FetchKeys(
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
@ -303,7 +324,8 @@ func (p *PerspectiveKeyFetcher) FetchKeys(
return nil, fmt.Errorf("gomatrixserverlib: key response from perspective server failed checks")
}
// TODO: What happens if the same key ID appears in multiple responses?
// TODO (matrix-org/dendrite#345): What happens if the same key ID
// appears in multiple responses?
// We should probably take the response with the highest valid_until_ts.
mapServerKeysToPublicKeyLookupResult(keys, results)
}
@ -318,6 +340,11 @@ type DirectKeyFetcher struct {
Client Client
}
// FetcherName implements KeyFetcher
func (d DirectKeyFetcher) FetcherName() string {
return "DirectKeyFetcher"
}
// FetchKeys implements KeyFetcher
func (d *DirectKeyFetcher) FetchKeys(
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
@ -333,9 +360,9 @@ func (d *DirectKeyFetcher) FetchKeys(
}
results := map[PublicKeyRequest]PublicKeyLookupResult{}
for server, reqs := range byServer {
for server := range byServer {
// TODO: make these requests in parallel
serverResults, err := d.fetchKeysForServer(ctx, server, reqs)
serverResults, err := d.fetchKeysForServer(ctx, server)
if err != nil {
// TODO: Should we actually be erroring here? or should we just drop those keys from the result map?
return nil, err
@ -348,25 +375,23 @@ func (d *DirectKeyFetcher) FetchKeys(
}
func (d *DirectKeyFetcher) fetchKeysForServer(
ctx context.Context, serverName ServerName, requests map[PublicKeyRequest]Timestamp,
ctx context.Context, serverName ServerName,
) (map[PublicKeyRequest]PublicKeyLookupResult, error) {
serverKeys, err := d.Client.LookupServerKeys(ctx, serverName, requests)
keys, err := d.Client.GetServerKeys(ctx, serverName)
if err != nil {
return nil, err
}
results := map[PublicKeyRequest]PublicKeyLookupResult{}
for _, keys := range serverKeys {
// Check that the keys are valid for the server.
checks, _, _ := CheckKeys(serverName, time.Unix(0, 0), keys, nil)
if !checks.AllChecksOK {
return nil, fmt.Errorf("gomatrixserverlib: key response direct from %q failed checks", serverName)
}
// TODO: What happens if the same key ID appears in multiple responses?
// We should probably take the response with the highest valid_until_ts.
results := map[PublicKeyRequest]PublicKeyLookupResult{}
// TODO (matrix-org/dendrite#345): What happens if the same key ID
// appears in multiple responses? We should probably reject the response.
mapServerKeysToPublicKeyLookupResult(keys, results)
}
return results, nil
}

View file

@ -36,6 +36,10 @@ var testKeys = `{
type testKeyDatabase struct{}
func (db testKeyDatabase) FetcherName() string {
return "testKeyDatabase"
}
func (db *testKeyDatabase) FetchKeys(
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
) (map[PublicKeyRequest]PublicKeyLookupResult, error) {
@ -151,6 +155,11 @@ func (e *erroringKeyDatabaseError) Error() string { return "An error with the ke
var testErrorFetch = erroringKeyDatabaseError(1)
var testErrorStore = erroringKeyDatabaseError(2)
// FetcherName implements KeyFetcher
func (e erroringKeyDatabase) FetcherName() string {
return "ErroringKeyDatabase"
}
func (e *erroringKeyDatabase) FetchKeys(
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
) (map[PublicKeyRequest]PublicKeyLookupResult, error) {

View file

@ -0,0 +1,20 @@
The MIT License (MIT)
Copyright (c) 2016 Josh Baker
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -0,0 +1,373 @@
<p align="center">
<img
src="logo.png"
width="240" height="78" border="0" alt="GJSON">
<br>
<a href="https://travis-ci.org/tidwall/gjson"><img src="https://img.shields.io/travis/tidwall/gjson.svg?style=flat-square" alt="Build Status"></a>
<a href="https://godoc.org/github.com/tidwall/gjson"><img src="https://img.shields.io/badge/api-reference-blue.svg?style=flat-square" alt="GoDoc"></a>
<a href="http://tidwall.com/gjson-play"><img src="https://img.shields.io/badge/play-ground-orange.svg?style=flat-square" alt="GJSON Playground"></a>
</p>
<p align="center">get a json value quickly</a></p>
GJSON is a Go package that provides a [fast](#performance) and [simple](#get-a-value) way to get values from a json document.
It has features such as [one line retrieval](#get-a-value), [dot notation paths](#path-syntax), [iteration](#iterate-through-an-object-or-array).
Getting Started
===============
## Installing
To start using GJSON, install Go and run `go get`:
```sh
$ go get -u github.com/tidwall/gjson
```
This will retrieve the library.
## Get a value
Get searches json for the specified path. A path is in dot syntax, such as "name.last" or "age". This function expects that the json is well-formed. Bad json will not panic, but it may return back unexpected results. When the value is found it's returned immediately.
```go
package main
import "github.com/tidwall/gjson"
const json = `{"name":{"first":"Janet","last":"Prichard"},"age":47}`
func main() {
value := gjson.Get(json, "name.last")
println(value.String())
}
```
This will print:
```
Prichard
```
*There's also the [GetMany](#get-multiple-values-at-once) function to get multiple values at once, and [GetBytes](#working-with-bytes) for working with JSON byte slices.*
## Path Syntax
A path is a series of keys separated by a dot.
A key may contain special wildcard characters '\*' and '?'.
To access an array value use the index as the key.
To get the number of elements in an array or to access a child path, use the '#' character.
The dot and wildcard characters can be escaped with '\\'.
```json
{
"name": {"first": "Tom", "last": "Anderson"},
"age":37,
"children": ["Sara","Alex","Jack"],
"fav.movie": "Deer Hunter",
"friends": [
{"first": "Dale", "last": "Murphy", "age": 44},
{"first": "Roger", "last": "Craig", "age": 68},
{"first": "Jane", "last": "Murphy", "age": 47}
]
}
```
```
"name.last" >> "Anderson"
"age" >> 37
"children" >> ["Sara","Alex","Jack"]
"children.#" >> 3
"children.1" >> "Alex"
"child*.2" >> "Jack"
"c?ildren.0" >> "Sara"
"fav\.movie" >> "Deer Hunter"
"friends.#.first" >> ["Dale","Roger","Jane"]
"friends.1.last" >> "Craig"
```
You can also query an array for the first match by using `#[...]`, or find all matches with `#[...]#`.
Queries support the `==`, `!=`, `<`, `<=`, `>`, `>=` comparison operators and the simple pattern matching `%` operator.
```
friends.#[last=="Murphy"].first >> "Dale"
friends.#[last=="Murphy"]#.first >> ["Dale","Jane"]
friends.#[age>45]#.last >> ["Craig","Murphy"]
friends.#[first%"D*"].last >> "Murphy"
```
## Result Type
GJSON supports the json types `string`, `number`, `bool`, and `null`.
Arrays and Objects are returned as their raw json types.
The `Result` type holds one of these:
```
bool, for JSON booleans
float64, for JSON numbers
string, for JSON string literals
nil, for JSON null
```
To directly access the value:
```go
result.Type // can be String, Number, True, False, Null, or JSON
result.Str // holds the string
result.Num // holds the float64 number
result.Raw // holds the raw json
result.Index // index of raw value in original json, zero means index unknown
```
There are a variety of handy functions that work on a result:
```go
result.Exists() bool
result.Value() interface{}
result.Int() int64
result.Uint() uint64
result.Float() float64
result.String() string
result.Bool() bool
result.Time() time.Time
result.Array() []gjson.Result
result.Map() map[string]gjson.Result
result.Get(path string) Result
result.ForEach(iterator func(key, value Result) bool)
result.Less(token Result, caseSensitive bool) bool
```
The `result.Value()` function returns an `interface{}` which requires type assertion and is one of the following Go types:
The `result.Array()` function returns back an array of values.
If the result represents a non-existent value, then an empty array will be returned.
If the result is not a JSON array, the return value will be an array containing one result.
```go
boolean >> bool
number >> float64
string >> string
null >> nil
array >> []interface{}
object >> map[string]interface{}
```
## Get nested array values
Suppose you want all the last names from the following json:
```json
{
"programmers": [
{
"firstName": "Janet",
"lastName": "McLaughlin",
}, {
"firstName": "Elliotte",
"lastName": "Hunter",
}, {
"firstName": "Jason",
"lastName": "Harold",
}
]
}
```
You would use the path "programmers.#.lastName" like such:
```go
result := gjson.Get(json, "programmers.#.lastName")
for _, name := range result.Array() {
println(name.String())
}
```
You can also query an object inside an array:
```go
name := gjson.Get(json, `programmers.#[lastName="Hunter"].firstName`)
println(name.String()) // prints "Elliotte"
```
## Iterate through an object or array
The `ForEach` function allows for quickly iterating through an object or array.
The key and value are passed to the iterator function for objects.
Only the value is passed for arrays.
Returning `false` from an iterator will stop iteration.
```go
result := gjson.Get(json, "programmers")
result.ForEach(func(key, value gjson.Result) bool {
println(value.String())
return true // keep iterating
})
```
## Simple Parse and Get
There's a `Parse(json)` function that will do a simple parse, and `result.Get(path)` that will search a result.
For example, all of these will return the same result:
```go
gjson.Parse(json).Get("name").Get("last")
gjson.Get(json, "name").Get("last")
gjson.Get(json, "name.last")
```
## Check for the existence of a value
Sometimes you just want to know if a value exists.
```go
value := gjson.Get(json, "name.last")
if !value.Exists() {
println("no last name")
} else {
println(value.String())
}
// Or as one step
if gjson.Get(json, "name.last").Exists() {
println("has a last name")
}
```
## Unmarshal to a map
To unmarshal to a `map[string]interface{}`:
```go
m, ok := gjson.Parse(json).Value().(map[string]interface{})
if !ok {
// not a map
}
```
## Working with Bytes
If your JSON is contained in a `[]byte` slice, there's the [GetBytes](https://godoc.org/github.com/tidwall/gjson#GetBytes) function. This is preferred over `Get(string(data), path)`.
```go
var json []byte = ...
result := gjson.GetBytes(json, path)
```
If you are using the `gjson.GetBytes(json, path)` function and you want to avoid converting `result.Raw` to a `[]byte`, then you can use this pattern:
```go
var json []byte = ...
result := gjson.GetBytes(json, path)
var raw []byte
if result.Index > 0 {
raw = json[result.Index:result.Index+len(result.Raw)]
} else {
raw = []byte(result.Raw)
}
```
This is a best-effort no allocation sub slice of the original json. This method utilizes the `result.Index` field, which is the position of the raw data in the original json. It's possible that the value of `result.Index` equals zero, in which case the `result.Raw` is converted to a `[]byte`.
## Get multiple values at once
The `GetMany` function can be used to get multiple values at the same time, and is optimized to scan over a JSON payload once.
```go
results := gjson.GetMany(json, "name.first", "name.last", "age")
```
The return value is a `[]Result`, which will always contain exactly the same number of items as the input paths.
## Performance
Benchmarks of GJSON alongside [encoding/json](https://golang.org/pkg/encoding/json/),
[ffjson](https://github.com/pquerna/ffjson),
[EasyJSON](https://github.com/mailru/easyjson),
[jsonparser](https://github.com/buger/jsonparser),
and [json-iterator](https://github.com/json-iterator/go)
```
BenchmarkGJSONGet-8 3000000 372 ns/op 0 B/op 0 allocs/op
BenchmarkGJSONUnmarshalMap-8 900000 4154 ns/op 1920 B/op 26 allocs/op
BenchmarkJSONUnmarshalMap-8 600000 9019 ns/op 3048 B/op 69 allocs/op
BenchmarkJSONDecoder-8 300000 14120 ns/op 4224 B/op 184 allocs/op
BenchmarkFFJSONLexer-8 1500000 3111 ns/op 896 B/op 8 allocs/op
BenchmarkEasyJSONLexer-8 3000000 887 ns/op 613 B/op 6 allocs/op
BenchmarkJSONParserGet-8 3000000 499 ns/op 21 B/op 0 allocs/op
BenchmarkJSONIterator-8 3000000 812 ns/op 544 B/op 9 allocs/op
```
Benchmarks for the `GetMany` function:
```
BenchmarkGJSONGetMany4Paths-8 4000000 303 ns/op 112 B/op 0 allocs/op
BenchmarkGJSONGetMany8Paths-8 8000000 208 ns/op 56 B/op 0 allocs/op
BenchmarkGJSONGetMany16Paths-8 16000000 156 ns/op 56 B/op 0 allocs/op
BenchmarkGJSONGetMany32Paths-8 32000000 127 ns/op 64 B/op 0 allocs/op
BenchmarkGJSONGetMany64Paths-8 64000000 117 ns/op 64 B/op 0 allocs/op
BenchmarkGJSONGetMany128Paths-8 128000000 109 ns/op 64 B/op 0 allocs/op
```
JSON document used:
```json
{
"widget": {
"debug": "on",
"window": {
"title": "Sample Konfabulator Widget",
"name": "main_window",
"width": 500,
"height": 500
},
"image": {
"src": "Images/Sun.png",
"hOffset": 250,
"vOffset": 250,
"alignment": "center"
},
"text": {
"data": "Click Here",
"size": 36,
"style": "bold",
"vOffset": 100,
"alignment": "center",
"onMouseUp": "sun1.opacity = (sun1.opacity / 100) * 90;"
}
}
}
```
Each operation was rotated though one of the following search paths:
```
widget.window.name
widget.image.hOffset
widget.text.onMouseUp
```
For the `GetMany` benchmarks these paths are used:
```
widget.window.name
widget.image.hOffset
widget.text.onMouseUp
widget.window.title
widget.image.alignment
widget.text.style
widget.window.height
widget.image.src
widget.text.data
widget.text.size
```
*These benchmarks were run on a MacBook Pro 15" 2.8 GHz Intel Core i7 using Go 1.8 and can be be found [here](https://github.com/tidwall/gjson-benchmarks).*
## Contact
Josh Baker [@tidwall](http://twitter.com/tidwall)
## License
GJSON source code is available under the MIT [License](/LICENSE).

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

View file

@ -0,0 +1,20 @@
The MIT License (MIT)
Copyright (c) 2016 Josh Baker
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -0,0 +1,32 @@
Match
=====
<a href="https://travis-ci.org/tidwall/match"><img src="https://img.shields.io/travis/tidwall/match.svg?style=flat-square" alt="Build Status"></a>
<a href="https://godoc.org/github.com/tidwall/match"><img src="https://img.shields.io/badge/api-reference-blue.svg?style=flat-square" alt="GoDoc"></a>
Match is a very simple pattern matcher where '*' matches on any
number characters and '?' matches on any one character.
Installing
----------
```
go get -u github.com/tidwall/match
```
Example
-------
```go
match.Match("hello", "*llo")
match.Match("jello", "?ello")
match.Match("hello", "h*o")
```
Contact
-------
Josh Baker [@tidwall](http://twitter.com/tidwall)
License
-------
Redcon source code is available under the MIT [License](/LICENSE).

View file

@ -0,0 +1,192 @@
// Match provides a simple pattern matcher with unicode support.
package match
import "unicode/utf8"
// Match returns true if str matches pattern. This is a very
// simple wildcard match where '*' matches on any number characters
// and '?' matches on any one character.
// pattern:
// { term }
// term:
// '*' matches any sequence of non-Separator characters
// '?' matches any single non-Separator character
// c matches character c (c != '*', '?', '\\')
// '\\' c matches character c
//
func Match(str, pattern string) bool {
if pattern == "*" {
return true
}
return deepMatch(str, pattern)
}
func deepMatch(str, pattern string) bool {
for len(pattern) > 0 {
if pattern[0] > 0x7f {
return deepMatchRune(str, pattern)
}
switch pattern[0] {
default:
if len(str) == 0 {
return false
}
if str[0] > 0x7f {
return deepMatchRune(str, pattern)
}
if str[0] != pattern[0] {
return false
}
case '?':
if len(str) == 0 {
return false
}
case '*':
return deepMatch(str, pattern[1:]) ||
(len(str) > 0 && deepMatch(str[1:], pattern))
}
str = str[1:]
pattern = pattern[1:]
}
return len(str) == 0 && len(pattern) == 0
}
func deepMatchRune(str, pattern string) bool {
var sr, pr rune
var srsz, prsz int
// read the first rune ahead of time
if len(str) > 0 {
if str[0] > 0x7f {
sr, srsz = utf8.DecodeRuneInString(str)
} else {
sr, srsz = rune(str[0]), 1
}
} else {
sr, srsz = utf8.RuneError, 0
}
if len(pattern) > 0 {
if pattern[0] > 0x7f {
pr, prsz = utf8.DecodeRuneInString(pattern)
} else {
pr, prsz = rune(pattern[0]), 1
}
} else {
pr, prsz = utf8.RuneError, 0
}
// done reading
for pr != utf8.RuneError {
switch pr {
default:
if srsz == utf8.RuneError {
return false
}
if sr != pr {
return false
}
case '?':
if srsz == utf8.RuneError {
return false
}
case '*':
return deepMatchRune(str, pattern[prsz:]) ||
(srsz > 0 && deepMatchRune(str[srsz:], pattern))
}
str = str[srsz:]
pattern = pattern[prsz:]
// read the next runes
if len(str) > 0 {
if str[0] > 0x7f {
sr, srsz = utf8.DecodeRuneInString(str)
} else {
sr, srsz = rune(str[0]), 1
}
} else {
sr, srsz = utf8.RuneError, 0
}
if len(pattern) > 0 {
if pattern[0] > 0x7f {
pr, prsz = utf8.DecodeRuneInString(pattern)
} else {
pr, prsz = rune(pattern[0]), 1
}
} else {
pr, prsz = utf8.RuneError, 0
}
// done reading
}
return srsz == 0 && prsz == 0
}
var maxRuneBytes = func() []byte {
b := make([]byte, 4)
if utf8.EncodeRune(b, '\U0010FFFF') != 4 {
panic("invalid rune encoding")
}
return b
}()
// Allowable parses the pattern and determines the minimum and maximum allowable
// values that the pattern can represent.
// When the max cannot be determined, 'true' will be returned
// for infinite.
func Allowable(pattern string) (min, max string) {
if pattern == "" || pattern[0] == '*' {
return "", ""
}
minb := make([]byte, 0, len(pattern))
maxb := make([]byte, 0, len(pattern))
var wild bool
for i := 0; i < len(pattern); i++ {
if pattern[i] == '*' {
wild = true
break
}
if pattern[i] == '?' {
minb = append(minb, 0)
maxb = append(maxb, maxRuneBytes...)
} else {
minb = append(minb, pattern[i])
maxb = append(maxb, pattern[i])
}
}
if wild {
r, n := utf8.DecodeLastRune(maxb)
if r != utf8.RuneError {
if r < utf8.MaxRune {
r++
if r > 0x7f {
b := make([]byte, 4)
nn := utf8.EncodeRune(b, r)
maxb = append(maxb[:len(maxb)-n], b[:nn]...)
} else {
maxb = append(maxb[:len(maxb)-n], byte(r))
}
}
}
}
return string(minb), string(maxb)
/*
return
if wild {
r, n := utf8.DecodeLastRune(maxb)
if r != utf8.RuneError {
if r < utf8.MaxRune {
infinite = true
} else {
r++
if r > 0x7f {
b := make([]byte, 4)
nn := utf8.EncodeRune(b, r)
maxb = append(maxb[:len(maxb)-n], b[:nn]...)
} else {
maxb = append(maxb[:len(maxb)-n], byte(r))
}
}
}
}
return string(minb), string(maxb), infinite
*/
}

View file

@ -0,0 +1,408 @@
package match
import (
"fmt"
"math/rand"
"testing"
"time"
"unicode/utf8"
)
func TestMatch(t *testing.T) {
if !Match("hello world", "hello world") {
t.Fatal("fail")
}
if Match("hello world", "jello world") {
t.Fatal("fail")
}
if !Match("hello world", "hello*") {
t.Fatal("fail")
}
if Match("hello world", "jello*") {
t.Fatal("fail")
}
if !Match("hello world", "hello?world") {
t.Fatal("fail")
}
if Match("hello world", "jello?world") {
t.Fatal("fail")
}
if !Match("hello world", "he*o?world") {
t.Fatal("fail")
}
if !Match("hello world", "he*o?wor*") {
t.Fatal("fail")
}
if !Match("hello world", "he*o?*r*") {
t.Fatal("fail")
}
if !Match("的情况下解析一个", "*") {
t.Fatal("fail")
}
if !Match("的情况下解析一个", "*况下*") {
t.Fatal("fail")
}
if !Match("的情况下解析一个", "*况?*") {
t.Fatal("fail")
}
if !Match("的情况下解析一个", "的情况?解析一个") {
t.Fatal("fail")
}
}
// TestWildcardMatch - Tests validate the logic of wild card matching.
// `WildcardMatch` supports '*' and '?' wildcards.
// Sample usage: In resource matching for folder policy validation.
func TestWildcardMatch(t *testing.T) {
testCases := []struct {
pattern string
text string
matched bool
}{
// Test case - 1.
// Test case with pattern containing key name with a prefix. Should accept the same text without a "*".
{
pattern: "my-folder/oo*",
text: "my-folder/oo",
matched: true,
},
// Test case - 2.
// Test case with "*" at the end of the pattern.
{
pattern: "my-folder/In*",
text: "my-folder/India/Karnataka/",
matched: true,
},
// Test case - 3.
// Test case with prefixes shuffled.
// This should fail.
{
pattern: "my-folder/In*",
text: "my-folder/Karnataka/India/",
matched: false,
},
// Test case - 4.
// Test case with text expanded to the wildcards in the pattern.
{
pattern: "my-folder/In*/Ka*/Ban",
text: "my-folder/India/Karnataka/Ban",
matched: true,
},
// Test case - 5.
// Test case with the keyname part is repeated as prefix several times.
// This is valid.
{
pattern: "my-folder/In*/Ka*/Ban",
text: "my-folder/India/Karnataka/Ban/Ban/Ban/Ban/Ban",
matched: true,
},
// Test case - 6.
// Test case to validate that `*` can be expanded into multiple prefixes.
{
pattern: "my-folder/In*/Ka*/Ban",
text: "my-folder/India/Karnataka/Area1/Area2/Area3/Ban",
matched: true,
},
// Test case - 7.
// Test case to validate that `*` can be expanded into multiple prefixes.
{
pattern: "my-folder/In*/Ka*/Ban",
text: "my-folder/India/State1/State2/Karnataka/Area1/Area2/Area3/Ban",
matched: true,
},
// Test case - 8.
// Test case where the keyname part of the pattern is expanded in the text.
{
pattern: "my-folder/In*/Ka*/Ban",
text: "my-folder/India/Karnataka/Bangalore",
matched: false,
},
// Test case - 9.
// Test case with prefixes and wildcard expanded for all "*".
{
pattern: "my-folder/In*/Ka*/Ban*",
text: "my-folder/India/Karnataka/Bangalore",
matched: true,
},
// Test case - 10.
// Test case with keyname part being a wildcard in the pattern.
{pattern: "my-folder/*",
text: "my-folder/India",
matched: true,
},
// Test case - 11.
{
pattern: "my-folder/oo*",
text: "my-folder/odo",
matched: false,
},
// Test case with pattern containing wildcard '?'.
// Test case - 12.
// "my-folder?/" matches "my-folder1/", "my-folder2/", "my-folder3" etc...
// doesn't match "myfolder/".
{
pattern: "my-folder?/abc*",
text: "myfolder/abc",
matched: false,
},
// Test case - 13.
{
pattern: "my-folder?/abc*",
text: "my-folder1/abc",
matched: true,
},
// Test case - 14.
{
pattern: "my-?-folder/abc*",
text: "my--folder/abc",
matched: false,
},
// Test case - 15.
{
pattern: "my-?-folder/abc*",
text: "my-1-folder/abc",
matched: true,
},
// Test case - 16.
{
pattern: "my-?-folder/abc*",
text: "my-k-folder/abc",
matched: true,
},
// Test case - 17.
{
pattern: "my??folder/abc*",
text: "myfolder/abc",
matched: false,
},
// Test case - 18.
{
pattern: "my??folder/abc*",
text: "my4afolder/abc",
matched: true,
},
// Test case - 19.
{
pattern: "my-folder?abc*",
text: "my-folder/abc",
matched: true,
},
// Test case 20-21.
// '?' matches '/' too. (works with s3).
// This is because the namespace is considered flat.
// "abc?efg" matches both "abcdefg" and "abc/efg".
{
pattern: "my-folder/abc?efg",
text: "my-folder/abcdefg",
matched: true,
},
{
pattern: "my-folder/abc?efg",
text: "my-folder/abc/efg",
matched: true,
},
// Test case - 22.
{
pattern: "my-folder/abc????",
text: "my-folder/abc",
matched: false,
},
// Test case - 23.
{
pattern: "my-folder/abc????",
text: "my-folder/abcde",
matched: false,
},
// Test case - 24.
{
pattern: "my-folder/abc????",
text: "my-folder/abcdefg",
matched: true,
},
// Test case 25-26.
// test case with no '*'.
{
pattern: "my-folder/abc?",
text: "my-folder/abc",
matched: false,
},
{
pattern: "my-folder/abc?",
text: "my-folder/abcd",
matched: true,
},
{
pattern: "my-folder/abc?",
text: "my-folder/abcde",
matched: false,
},
// Test case 27.
{
pattern: "my-folder/mnop*?",
text: "my-folder/mnop",
matched: false,
},
// Test case 28.
{
pattern: "my-folder/mnop*?",
text: "my-folder/mnopqrst/mnopqr",
matched: true,
},
// Test case 29.
{
pattern: "my-folder/mnop*?",
text: "my-folder/mnopqrst/mnopqrs",
matched: true,
},
// Test case 30.
{
pattern: "my-folder/mnop*?",
text: "my-folder/mnop",
matched: false,
},
// Test case 31.
{
pattern: "my-folder/mnop*?",
text: "my-folder/mnopq",
matched: true,
},
// Test case 32.
{
pattern: "my-folder/mnop*?",
text: "my-folder/mnopqr",
matched: true,
},
// Test case 33.
{
pattern: "my-folder/mnop*?and",
text: "my-folder/mnopqand",
matched: true,
},
// Test case 34.
{
pattern: "my-folder/mnop*?and",
text: "my-folder/mnopand",
matched: false,
},
// Test case 35.
{
pattern: "my-folder/mnop*?and",
text: "my-folder/mnopqand",
matched: true,
},
// Test case 36.
{
pattern: "my-folder/mnop*?",
text: "my-folder/mn",
matched: false,
},
// Test case 37.
{
pattern: "my-folder/mnop*?",
text: "my-folder/mnopqrst/mnopqrs",
matched: true,
},
// Test case 38.
{
pattern: "my-folder/mnop*??",
text: "my-folder/mnopqrst",
matched: true,
},
// Test case 39.
{
pattern: "my-folder/mnop*qrst",
text: "my-folder/mnopabcdegqrst",
matched: true,
},
// Test case 40.
{
pattern: "my-folder/mnop*?and",
text: "my-folder/mnopqand",
matched: true,
},
// Test case 41.
{
pattern: "my-folder/mnop*?and",
text: "my-folder/mnopand",
matched: false,
},
// Test case 42.
{
pattern: "my-folder/mnop*?and?",
text: "my-folder/mnopqanda",
matched: true,
},
// Test case 43.
{
pattern: "my-folder/mnop*?and",
text: "my-folder/mnopqanda",
matched: false,
},
// Test case 44.
{
pattern: "my-?-folder/abc*",
text: "my-folder/mnopqanda",
matched: false,
},
}
// Iterating over the test cases, call the function under test and asert the output.
for i, testCase := range testCases {
actualResult := Match(testCase.text, testCase.pattern)
if testCase.matched != actualResult {
t.Errorf("Test %d: Expected the result to be `%v`, but instead found it to be `%v`", i+1, testCase.matched, actualResult)
}
}
}
func TestRandomInput(t *testing.T) {
rand.Seed(time.Now().UnixNano())
b1 := make([]byte, 100)
b2 := make([]byte, 100)
for i := 0; i < 1000000; i++ {
if _, err := rand.Read(b1); err != nil {
t.Fatal(err)
}
if _, err := rand.Read(b2); err != nil {
t.Fatal(err)
}
Match(string(b1), string(b2))
}
}
func testAllowable(pattern, exmin, exmax string) error {
min, max := Allowable(pattern)
if min != exmin || max != exmax {
return fmt.Errorf("expected '%v'/'%v', got '%v'/'%v'",
exmin, exmax, min, max)
}
return nil
}
func TestAllowable(t *testing.T) {
if err := testAllowable("hell*", "hell", "helm"); err != nil {
t.Fatal(err)
}
if err := testAllowable("hell?", "hell"+string(0), "hell"+string(utf8.MaxRune)); err != nil {
t.Fatal(err)
}
if err := testAllowable("h解析ell*", "h解析ell", "h解析elm"); err != nil {
t.Fatal(err)
}
if err := testAllowable("h解*ell*", "h解", "h觤"); err != nil {
t.Fatal(err)
}
}
func BenchmarkAscii(t *testing.B) {
for i := 0; i < t.N; i++ {
if !Match("hello", "hello") {
t.Fatal("fail")
}
}
}
func BenchmarkUnicode(t *testing.B) {
for i := 0; i < t.N; i++ {
if !Match("h情llo", "h情llo") {
t.Fatal("fail")
}
}
}

View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2016 Josh Baker
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -0,0 +1,278 @@
<p align="center">
<img
src="logo.png"
width="240" height="78" border="0" alt="SJSON">
<br>
<a href="https://travis-ci.org/tidwall/sjson"><img src="https://img.shields.io/travis/tidwall/sjson.svg?style=flat-square" alt="Build Status"></a>
<a href="https://godoc.org/github.com/tidwall/sjson"><img src="https://img.shields.io/badge/api-reference-blue.svg?style=flat-square" alt="GoDoc"></a>
</p>
<p align="center">set a json value quickly</a></p>
SJSON is a Go package that provides a [very fast](#performance) and simple way to set a value in a json document. The purpose for this library is to provide efficient json updating for the [SummitDB](https://github.com/tidwall/summitdb) project.
For quickly retrieving json values check out [GJSON](https://github.com/tidwall/gjson).
For a command line interface check out [JSONed](https://github.com/tidwall/jsoned).
Getting Started
===============
Installing
----------
To start using SJSON, install Go and run `go get`:
```sh
$ go get -u github.com/tidwall/sjson
```
This will retrieve the library.
Set a value
-----------
Set sets the value for the specified path.
A path is in dot syntax, such as "name.last" or "age".
This function expects that the json is well-formed and validated.
Invalid json will not panic, but it may return back unexpected results.
Invalid paths may return an error.
```go
package main
import "github.com/tidwall/sjson"
const json = `{"name":{"first":"Janet","last":"Prichard"},"age":47}`
func main() {
value, _ := sjson.Set(json, "name.last", "Anderson")
println(value)
}
```
This will print:
```json
{"name":{"first":"Janet","last":"Anderson"},"age":47}
```
Path syntax
-----------
A path is a series of keys separated by a dot.
The dot and colon characters can be escaped with '\'.
```json
{
"name": {"first": "Tom", "last": "Anderson"},
"age":37,
"children": ["Sara","Alex","Jack"],
"fav.movie": "Deer Hunter",
"friends": [
{"first": "James", "last": "Murphy"},
{"first": "Roger", "last": "Craig"}
]
}
```
```
"name.last" >> "Anderson"
"age" >> 37
"children.1" >> "Alex"
"friends.1.last" >> "Craig"
```
The `-1` key can be used to append a value to an existing array:
```
"children.-1" >> appends a new value to the end of the children array
```
Normally number keys are used to modify arrays, but it's possible to force a numeric object key by using the colon character:
```json
{
"users":{
"2313":{"name":"Sara"},
"7839":{"name":"Andy"}
}
}
```
A colon path would look like:
```
"users.:2313.name" >> "Sara"
```
Supported types
---------------
Pretty much any type is supported:
```go
sjson.Set(`{"key":true}`, "key", nil)
sjson.Set(`{"key":true}`, "key", false)
sjson.Set(`{"key":true}`, "key", 1)
sjson.Set(`{"key":true}`, "key", 10.5)
sjson.Set(`{"key":true}`, "key", "hello")
sjson.Set(`{"key":true}`, "key", map[string]interface{}{"hello":"world"})
```
When a type is not recognized, SJSON will fallback to the `encoding/json` Marshaller.
Examples
--------
Set a value from empty document:
```go
value, _ := sjson.Set("", "name", "Tom")
println(value)
// Output:
// {"name":"Tom"}
```
Set a nested value from empty document:
```go
value, _ := sjson.Set("", "name.last", "Anderson")
println(value)
// Output:
// {"name":{"last":"Anderson"}}
```
Set a new value:
```go
value, _ := sjson.Set(`{"name":{"last":"Anderson"}}`, "name.first", "Sara")
println(value)
// Output:
// {"name":{"first":"Sara","last":"Anderson"}}
```
Update an existing value:
```go
value, _ := sjson.Set(`{"name":{"last":"Anderson"}}`, "name.last", "Smith")
println(value)
// Output:
// {"name":{"last":"Smith"}}
```
Set a new array value:
```go
value, _ := sjson.Set(`{"friends":["Andy","Carol"]}`, "friends.2", "Sara")
println(value)
// Output:
// {"friends":["Andy","Carol","Sara"]
```
Append an array value by using the `-1` key in a path:
```go
value, _ := sjson.Set(`{"friends":["Andy","Carol"]}`, "friends.-1", "Sara")
println(value)
// Output:
// {"friends":["Andy","Carol","Sara"]
```
Append an array value that is past the end:
```go
value, _ := sjson.Set(`{"friends":["Andy","Carol"]}`, "friends.4", "Sara")
println(value)
// Output:
// {"friends":["Andy","Carol",null,null,"Sara"]
```
Delete a value:
```go
value, _ := sjson.Delete(`{"name":{"first":"Sara","last":"Anderson"}}`, "name.first")
println(value)
// Output:
// {"name":{"last":"Anderson"}}
```
Delete an array value:
```go
value, _ := sjson.Delete(`{"friends":["Andy","Carol"]}`, "friends.1")
println(value)
// Output:
// {"friends":["Andy"]}
```
Delete the last array value:
```go
value, _ := sjson.Delete(`{"friends":["Andy","Carol"]}`, "friends.-1")
println(value)
// Output:
// {"friends":["Andy"]}
```
## Performance
Benchmarks of SJSON alongside [encoding/json](https://golang.org/pkg/encoding/json/),
[ffjson](https://github.com/pquerna/ffjson),
[EasyJSON](https://github.com/mailru/easyjson),
and [Gabs](https://github.com/Jeffail/gabs)
```
Benchmark_SJSON-8 3000000 805 ns/op 1077 B/op 3 allocs/op
Benchmark_SJSON_ReplaceInPlace-8 3000000 449 ns/op 0 B/op 0 allocs/op
Benchmark_JSON_Map-8 300000 21236 ns/op 6392 B/op 150 allocs/op
Benchmark_JSON_Struct-8 300000 14691 ns/op 1789 B/op 24 allocs/op
Benchmark_Gabs-8 300000 21311 ns/op 6752 B/op 150 allocs/op
Benchmark_FFJSON-8 300000 17673 ns/op 3589 B/op 47 allocs/op
Benchmark_EasyJSON-8 1500000 3119 ns/op 1061 B/op 13 allocs/op
```
JSON document used:
```json
{
"widget": {
"debug": "on",
"window": {
"title": "Sample Konfabulator Widget",
"name": "main_window",
"width": 500,
"height": 500
},
"image": {
"src": "Images/Sun.png",
"hOffset": 250,
"vOffset": 250,
"alignment": "center"
},
"text": {
"data": "Click Here",
"size": 36,
"style": "bold",
"vOffset": 100,
"alignment": "center",
"onMouseUp": "sun1.opacity = (sun1.opacity / 100) * 90;"
}
}
}
```
Each operation was rotated though one of the following search paths:
```
widget.window.name
widget.image.hOffset
widget.text.onMouseUp
```
*These benchmarks were run on a MacBook Pro 15" 2.8 GHz Intel Core i7 using Go 1.7.*
## Contact
Josh Baker [@tidwall](http://twitter.com/tidwall)
## License
SJSON source code is available under the MIT [License](/LICENSE).

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

View file

@ -0,0 +1,653 @@
// Package sjson provides setting json values.
package sjson
import (
jsongo "encoding/json"
"reflect"
"strconv"
"unsafe"
"github.com/tidwall/gjson"
)
type errorType struct {
msg string
}
func (err *errorType) Error() string {
return err.msg
}
// Options represents additional options for the Set and Delete functions.
type Options struct {
// Optimistic is a hint that the value likely exists which
// allows for the sjson to perform a fast-track search and replace.
Optimistic bool
// ReplaceInPlace is a hint to replace the input json rather than
// allocate a new json byte slice. When this field is specified
// the input json will not longer be valid and it should not be used
// In the case when the destination slice doesn't have enough free
// bytes to replace the data in place, a new bytes slice will be
// created under the hood.
// The Optimistic flag must be set to true and the input must be a
// byte slice in order to use this field.
ReplaceInPlace bool
}
type pathResult struct {
part string // current key part
path string // remaining path
force bool // force a string key
more bool // there is more path to parse
}
func parsePath(path string) (pathResult, error) {
var r pathResult
if len(path) > 0 && path[0] == ':' {
r.force = true
path = path[1:]
}
for i := 0; i < len(path); i++ {
if path[i] == '.' {
r.part = path[:i]
r.path = path[i+1:]
r.more = true
return r, nil
}
if path[i] == '*' || path[i] == '?' {
return r, &errorType{"wildcard characters not allowed in path"}
} else if path[i] == '#' {
return r, &errorType{"array access character not allowed in path"}
}
if path[i] == '\\' {
// go into escape mode. this is a slower path that
// strips off the escape character from the part.
epart := []byte(path[:i])
i++
if i < len(path) {
epart = append(epart, path[i])
i++
for ; i < len(path); i++ {
if path[i] == '\\' {
i++
if i < len(path) {
epart = append(epart, path[i])
}
continue
} else if path[i] == '.' {
r.part = string(epart)
r.path = path[i+1:]
r.more = true
return r, nil
} else if path[i] == '*' || path[i] == '?' {
return r, &errorType{
"wildcard characters not allowed in path"}
} else if path[i] == '#' {
return r, &errorType{
"array access character not allowed in path"}
}
epart = append(epart, path[i])
}
}
// append the last part
r.part = string(epart)
return r, nil
}
}
r.part = path
return r, nil
}
func mustMarshalString(s string) bool {
for i := 0; i < len(s); i++ {
if s[i] < ' ' || s[i] > 0x7f || s[i] == '"' {
return true
}
}
return false
}
// appendStringify makes a json string and appends to buf.
func appendStringify(buf []byte, s string) []byte {
if mustMarshalString(s) {
b, _ := jsongo.Marshal(s)
return append(buf, b...)
}
buf = append(buf, '"')
buf = append(buf, s...)
buf = append(buf, '"')
return buf
}
// appendBuild builds a json block from a json path.
func appendBuild(buf []byte, array bool, paths []pathResult, raw string,
stringify bool) []byte {
if !array {
buf = appendStringify(buf, paths[0].part)
buf = append(buf, ':')
}
if len(paths) > 1 {
n, numeric := atoui(paths[1])
if numeric || (!paths[1].force && paths[1].part == "-1") {
buf = append(buf, '[')
buf = appendRepeat(buf, "null,", n)
buf = appendBuild(buf, true, paths[1:], raw, stringify)
buf = append(buf, ']')
} else {
buf = append(buf, '{')
buf = appendBuild(buf, false, paths[1:], raw, stringify)
buf = append(buf, '}')
}
} else {
if stringify {
buf = appendStringify(buf, raw)
} else {
buf = append(buf, raw...)
}
}
return buf
}
// atoui does a rip conversion of string -> unigned int.
func atoui(r pathResult) (n int, ok bool) {
if r.force {
return 0, false
}
for i := 0; i < len(r.part); i++ {
if r.part[i] < '0' || r.part[i] > '9' {
return 0, false
}
n = n*10 + int(r.part[i]-'0')
}
return n, true
}
// appendRepeat repeats string "n" times and appends to buf.
func appendRepeat(buf []byte, s string, n int) []byte {
for i := 0; i < n; i++ {
buf = append(buf, s...)
}
return buf
}
// trim does a rip trim
func trim(s string) string {
for len(s) > 0 {
if s[0] <= ' ' {
s = s[1:]
continue
}
break
}
for len(s) > 0 {
if s[len(s)-1] <= ' ' {
s = s[:len(s)-1]
continue
}
break
}
return s
}
// deleteTailItem deletes the previous key or comma.
func deleteTailItem(buf []byte) ([]byte, bool) {
loop:
for i := len(buf) - 1; i >= 0; i-- {
// look for either a ',',':','['
switch buf[i] {
case '[':
return buf, true
case ',':
return buf[:i], false
case ':':
// delete tail string
i--
for ; i >= 0; i-- {
if buf[i] == '"' {
i--
for ; i >= 0; i-- {
if buf[i] == '"' {
i--
if i >= 0 && i == '\\' {
i--
continue
}
for ; i >= 0; i-- {
// look for either a ',','{'
switch buf[i] {
case '{':
return buf[:i+1], true
case ',':
return buf[:i], false
}
}
}
}
break
}
}
break loop
}
}
return buf, false
}
var errNoChange = &errorType{"no change"}
func appendRawPaths(buf []byte, jstr string, paths []pathResult, raw string,
stringify, del bool) ([]byte, error) {
var err error
var res gjson.Result
var found bool
if del {
if paths[0].part == "-1" && !paths[0].force {
res = gjson.Get(jstr, "#")
if res.Int() > 0 {
res = gjson.Get(jstr, strconv.FormatInt(int64(res.Int()-1), 10))
found = true
}
}
}
if !found {
res = gjson.Get(jstr, paths[0].part)
}
if res.Index > 0 {
if len(paths) > 1 {
buf = append(buf, jstr[:res.Index]...)
buf, err = appendRawPaths(buf, res.Raw, paths[1:], raw,
stringify, del)
if err != nil {
return nil, err
}
buf = append(buf, jstr[res.Index+len(res.Raw):]...)
return buf, nil
}
buf = append(buf, jstr[:res.Index]...)
var exidx int // additional forward stripping
if del {
var delNextComma bool
buf, delNextComma = deleteTailItem(buf)
if delNextComma {
i, j := res.Index+len(res.Raw), 0
for ; i < len(jstr); i, j = i+1, j+1 {
if jstr[i] <= ' ' {
continue
}
if jstr[i] == ',' {
exidx = j + 1
}
break
}
}
} else {
if stringify {
buf = appendStringify(buf, raw)
} else {
buf = append(buf, raw...)
}
}
buf = append(buf, jstr[res.Index+len(res.Raw)+exidx:]...)
return buf, nil
}
if del {
return nil, errNoChange
}
n, numeric := atoui(paths[0])
isempty := true
for i := 0; i < len(jstr); i++ {
if jstr[i] > ' ' {
isempty = false
break
}
}
if isempty {
if numeric {
jstr = "[]"
} else {
jstr = "{}"
}
}
jsres := gjson.Parse(jstr)
if jsres.Type != gjson.JSON {
if numeric {
jstr = "[]"
} else {
jstr = "{}"
}
jsres = gjson.Parse(jstr)
}
var comma bool
for i := 1; i < len(jsres.Raw); i++ {
if jsres.Raw[i] <= ' ' {
continue
}
if jsres.Raw[i] == '}' || jsres.Raw[i] == ']' {
break
}
comma = true
break
}
switch jsres.Raw[0] {
default:
return nil, &errorType{"json must be an object or array"}
case '{':
buf = append(buf, '{')
buf = appendBuild(buf, false, paths, raw, stringify)
if comma {
buf = append(buf, ',')
}
buf = append(buf, jsres.Raw[1:]...)
return buf, nil
case '[':
var appendit bool
if !numeric {
if paths[0].part == "-1" && !paths[0].force {
appendit = true
} else {
return nil, &errorType{
"cannot set array element for non-numeric key '" +
paths[0].part + "'"}
}
}
if appendit {
njson := trim(jsres.Raw)
if njson[len(njson)-1] == ']' {
njson = njson[:len(njson)-1]
}
buf = append(buf, njson...)
if comma {
buf = append(buf, ',')
}
buf = appendBuild(buf, true, paths, raw, stringify)
buf = append(buf, ']')
return buf, nil
}
buf = append(buf, '[')
ress := jsres.Array()
for i := 0; i < len(ress); i++ {
if i > 0 {
buf = append(buf, ',')
}
buf = append(buf, ress[i].Raw...)
}
if len(ress) == 0 {
buf = appendRepeat(buf, "null,", n-len(ress))
} else {
buf = appendRepeat(buf, ",null", n-len(ress))
if comma {
buf = append(buf, ',')
}
}
buf = appendBuild(buf, true, paths, raw, stringify)
buf = append(buf, ']')
return buf, nil
}
}
func isOptimisticPath(path string) bool {
for i := 0; i < len(path); i++ {
if path[i] < '.' || path[i] > 'z' {
return false
}
if path[i] > '9' && path[i] < 'A' {
return false
}
if path[i] > 'z' {
return false
}
}
return true
}
func set(jstr, path, raw string,
stringify, del, optimistic, inplace bool) ([]byte, error) {
if path == "" {
return nil, &errorType{"path cannot be empty"}
}
if !del && optimistic && isOptimisticPath(path) {
res := gjson.Get(jstr, path)
if res.Exists() && res.Index > 0 {
sz := len(jstr) - len(res.Raw) + len(raw)
if stringify {
sz += 2
}
if inplace && sz <= len(jstr) {
if !stringify || !mustMarshalString(raw) {
jsonh := *(*reflect.StringHeader)(unsafe.Pointer(&jstr))
jsonbh := reflect.SliceHeader{
Data: jsonh.Data, Len: jsonh.Len, Cap: jsonh.Len}
jbytes := *(*[]byte)(unsafe.Pointer(&jsonbh))
if stringify {
jbytes[res.Index] = '"'
copy(jbytes[res.Index+1:], []byte(raw))
jbytes[res.Index+1+len(raw)] = '"'
copy(jbytes[res.Index+1+len(raw)+1:],
jbytes[res.Index+len(res.Raw):])
} else {
copy(jbytes[res.Index:], []byte(raw))
copy(jbytes[res.Index+len(raw):],
jbytes[res.Index+len(res.Raw):])
}
return jbytes[:sz], nil
}
return nil, nil
}
buf := make([]byte, 0, sz)
buf = append(buf, jstr[:res.Index]...)
if stringify {
buf = appendStringify(buf, raw)
} else {
buf = append(buf, raw...)
}
buf = append(buf, jstr[res.Index+len(res.Raw):]...)
return buf, nil
}
}
// parse the path, make sure that it does not contain invalid characters
// such as '#', '?', '*'
paths := make([]pathResult, 0, 4)
r, err := parsePath(path)
if err != nil {
return nil, err
}
paths = append(paths, r)
for r.more {
if r, err = parsePath(r.path); err != nil {
return nil, err
}
paths = append(paths, r)
}
njson, err := appendRawPaths(nil, jstr, paths, raw, stringify, del)
if err != nil {
return nil, err
}
return njson, nil
}
// Set sets a json value for the specified path.
// A path is in dot syntax, such as "name.last" or "age".
// This function expects that the json is well-formed, and does not validate.
// Invalid json will not panic, but it may return back unexpected results.
// An error is returned if the path is not valid.
//
// A path is a series of keys separated by a dot.
//
// {
// "name": {"first": "Tom", "last": "Anderson"},
// "age":37,
// "children": ["Sara","Alex","Jack"],
// "friends": [
// {"first": "James", "last": "Murphy"},
// {"first": "Roger", "last": "Craig"}
// ]
// }
// "name.last" >> "Anderson"
// "age" >> 37
// "children.1" >> "Alex"
//
func Set(json, path string, value interface{}) (string, error) {
return SetOptions(json, path, value, nil)
}
// SetOptions sets a json value for the specified path with options.
// A path is in dot syntax, such as "name.last" or "age".
// This function expects that the json is well-formed, and does not validate.
// Invalid json will not panic, but it may return back unexpected results.
// An error is returned if the path is not valid.
func SetOptions(json, path string, value interface{},
opts *Options) (string, error) {
if opts != nil {
if opts.ReplaceInPlace {
// it's not safe to replace bytes in-place for strings
// copy the Options and set options.ReplaceInPlace to false.
nopts := *opts
opts = &nopts
opts.ReplaceInPlace = false
}
}
jsonh := *(*reflect.StringHeader)(unsafe.Pointer(&json))
jsonbh := reflect.SliceHeader{Data: jsonh.Data, Len: jsonh.Len}
jsonb := *(*[]byte)(unsafe.Pointer(&jsonbh))
res, err := SetBytesOptions(jsonb, path, value, opts)
return string(res), err
}
// SetBytes sets a json value for the specified path.
// If working with bytes, this method preferred over
// Set(string(data), path, value)
func SetBytes(json []byte, path string, value interface{}) ([]byte, error) {
return SetBytesOptions(json, path, value, nil)
}
// SetBytesOptions sets a json value for the specified path with options.
// If working with bytes, this method preferred over
// SetOptions(string(data), path, value)
func SetBytesOptions(json []byte, path string, value interface{},
opts *Options) ([]byte, error) {
var optimistic, inplace bool
if opts != nil {
optimistic = opts.Optimistic
inplace = opts.ReplaceInPlace
}
jstr := *(*string)(unsafe.Pointer(&json))
var res []byte
var err error
switch v := value.(type) {
default:
b, err := jsongo.Marshal(value)
if err != nil {
return nil, err
}
raw := *(*string)(unsafe.Pointer(&b))
res, err = set(jstr, path, raw, false, false, optimistic, inplace)
case dtype:
res, err = set(jstr, path, "", false, true, optimistic, inplace)
case string:
res, err = set(jstr, path, v, true, false, optimistic, inplace)
case []byte:
raw := *(*string)(unsafe.Pointer(&v))
res, err = set(jstr, path, raw, true, false, optimistic, inplace)
case bool:
if v {
res, err = set(jstr, path, "true", false, false, optimistic, inplace)
} else {
res, err = set(jstr, path, "false", false, false, optimistic, inplace)
}
case int8:
res, err = set(jstr, path, strconv.FormatInt(int64(v), 10),
false, false, optimistic, inplace)
case int16:
res, err = set(jstr, path, strconv.FormatInt(int64(v), 10),
false, false, optimistic, inplace)
case int32:
res, err = set(jstr, path, strconv.FormatInt(int64(v), 10),
false, false, optimistic, inplace)
case int64:
res, err = set(jstr, path, strconv.FormatInt(int64(v), 10),
false, false, optimistic, inplace)
case uint8:
res, err = set(jstr, path, strconv.FormatUint(uint64(v), 10),
false, false, optimistic, inplace)
case uint16:
res, err = set(jstr, path, strconv.FormatUint(uint64(v), 10),
false, false, optimistic, inplace)
case uint32:
res, err = set(jstr, path, strconv.FormatUint(uint64(v), 10),
false, false, optimistic, inplace)
case uint64:
res, err = set(jstr, path, strconv.FormatUint(uint64(v), 10),
false, false, optimistic, inplace)
case float32:
res, err = set(jstr, path, strconv.FormatFloat(float64(v), 'f', -1, 64),
false, false, optimistic, inplace)
case float64:
res, err = set(jstr, path, strconv.FormatFloat(float64(v), 'f', -1, 64),
false, false, optimistic, inplace)
}
if err == errNoChange {
return json, nil
}
return res, err
}
// SetRaw sets a raw json value for the specified path.
// This function works the same as Set except that the value is set as a
// raw block of json. This allows for setting premarshalled json objects.
func SetRaw(json, path, value string) (string, error) {
return SetRawOptions(json, path, value, nil)
}
// SetRawOptions sets a raw json value for the specified path with options.
// This furnction works the same as SetOptions except that the value is set
// as a raw block of json. This allows for setting premarshalled json objects.
func SetRawOptions(json, path, value string, opts *Options) (string, error) {
var optimistic bool
if opts != nil {
optimistic = opts.Optimistic
}
res, err := set(json, path, value, false, false, optimistic, false)
if err == errNoChange {
return json, nil
}
return string(res), err
}
// SetRawBytes sets a raw json value for the specified path.
// If working with bytes, this method preferred over
// SetRaw(string(data), path, value)
func SetRawBytes(json []byte, path string, value []byte) ([]byte, error) {
return SetRawBytesOptions(json, path, value, nil)
}
// SetRawBytesOptions sets a raw json value for the specified path with options.
// If working with bytes, this method preferred over
// SetRawOptions(string(data), path, value, opts)
func SetRawBytesOptions(json []byte, path string, value []byte,
opts *Options) ([]byte, error) {
jstr := *(*string)(unsafe.Pointer(&json))
vstr := *(*string)(unsafe.Pointer(&value))
var optimistic, inplace bool
if opts != nil {
optimistic = opts.Optimistic
inplace = opts.ReplaceInPlace
}
res, err := set(jstr, path, vstr, false, false, optimistic, inplace)
if err == errNoChange {
return json, nil
}
return res, err
}
type dtype struct{}
// Delete deletes a value from json for the specified path.
func Delete(json, path string) (string, error) {
return Set(json, path, dtype{})
}
// DeleteBytes deletes a value from json for the specified path.
func DeleteBytes(json []byte, path string) ([]byte, error) {
return SetBytes(json, path, dtype{})
}

File diff suppressed because it is too large Load diff