diff --git a/src/github.com/matrix-org/dendrite/clientapi/writers/register.go b/src/github.com/matrix-org/dendrite/clientapi/writers/register.go index c7547ca34..40901bbf5 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/writers/register.go +++ b/src/github.com/matrix-org/dendrite/clientapi/writers/register.go @@ -4,7 +4,6 @@ import ( "context" "crypto/hmac" "crypto/sha1" - "encoding/hex" "errors" "fmt" "net/http" @@ -46,9 +45,9 @@ type registerRequest struct { } type authDict struct { - Type authtypes.LoginType `json:"type"` - Session string `json:"session"` - Mac string `json:"mac"` + Type authtypes.LoginType `json:"type"` + Session string `json:"session"` + Mac gomatrixserverlib.HexString `json:"mac"` // TODO: Lots of custom keys depending on the type } @@ -72,7 +71,7 @@ type legacyRegisterRequest struct { Username string `json:"user"` Admin bool `json:"admin"` Type authtypes.LoginType `json:"type"` - Mac string `json:"mac"` + Mac gomatrixserverlib.HexString `json:"mac"` } func newUserInteractiveResponse(sessionID string, fs []authFlow) userInteractiveResponse { @@ -301,7 +300,8 @@ func completeRegistration( func isValidMacLogin( username, password string, isAdmin bool, - givenMacStr, sharedSecret string, + givenMac []byte, + sharedSecret string, ) (bool, error) { // Double check that username/passowrd don't contain the HMAC delimiters. We should have // already checked this. @@ -328,10 +328,5 @@ func isValidMacLogin( } expectedMAC := mac.Sum(nil) - givenMac, err := hex.DecodeString(givenMacStr) - if err != nil { - return false, err - } - return hmac.Equal(givenMac, expectedMAC), nil } diff --git a/vendor/manifest b/vendor/manifest index e30b2f2e7..9200690e8 100644 --- a/vendor/manifest +++ b/vendor/manifest @@ -116,7 +116,7 @@ { "importpath": "github.com/matrix-org/gomatrixserverlib", "repository": "https://github.com/matrix-org/gomatrixserverlib", - "revision": "40b35e1c997fc7e35342aeb39187ff6bf3e10b2e", + "revision": "ce6f4766251e31487906dfaaebd7d7cfea147252", "branch": "master" }, { diff --git a/vendor/src/github.com/matrix-org/gomatrixserverlib/hex_string.go b/vendor/src/github.com/matrix-org/gomatrixserverlib/hex_string.go new file mode 100644 index 000000000..883307d53 --- /dev/null +++ b/vendor/src/github.com/matrix-org/gomatrixserverlib/hex_string.go @@ -0,0 +1,44 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package gomatrixserverlib + +import ( + "encoding/hex" + "encoding/json" +) + +// A HexString is a string of bytes that are hex encoded when used in JSON. +// The bytes encoded using hex when marshalled as JSON. +// When the bytes are unmarshalled from JSON they are decoded from hex. +type HexString []byte + +// MarshalJSON encodes the bytes as hex and then encodes the hex as a JSON string. +// This takes a value receiver so that maps and slices of HexString encode correctly. +func (h HexString) MarshalJSON() ([]byte, error) { + return json.Marshal(hex.EncodeToString(h)) +} + +// UnmarshalJSON decodes a JSON string and then decodes the resulting hex. +// This takes a pointer receiver because it needs to write the result of decoding. +func (h *HexString) UnmarshalJSON(raw []byte) (err error) { + var str string + if err = json.Unmarshal(raw, &str); err != nil { + return + } + + *h, err = hex.DecodeString(str) + return +} diff --git a/vendor/src/github.com/matrix-org/gomatrixserverlib/hex_string_test.go b/vendor/src/github.com/matrix-org/gomatrixserverlib/hex_string_test.go new file mode 100644 index 000000000..6cb048622 --- /dev/null +++ b/vendor/src/github.com/matrix-org/gomatrixserverlib/hex_string_test.go @@ -0,0 +1,82 @@ +/* Copyright 2016-2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package gomatrixserverlib + +import ( + "encoding/json" + "testing" +) + +func TestMarshalHex(t *testing.T) { + input := HexString("this\xffis\xffa\xfftest") + want := `"74686973ff6973ff61ff74657374"` + got, err := json.Marshal(input) + if err != nil { + t.Fatal(err) + } + if string(got) != want { + t.Fatalf("json.Marshal(HexString(%q)): wanted %q got %q", string(input), want, string(got)) + } +} + +func TestUnmarshalHex(t *testing.T) { + input := []byte(`"74686973ff6973ff61ff74657374"`) + want := "this\xffis\xffa\xfftest" + var got HexString + err := json.Unmarshal(input, &got) + if err != nil { + t.Fatal(err) + } + if string(got) != want { + t.Fatalf("json.Unmarshal(%q): wanted %q got %q", string(input), want, string(got)) + } +} + +func TestMarshalHexStruct(t *testing.T) { + input := struct{ Value HexString }{HexString("this\xffis\xffa\xfftest")} + want := `{"Value":"74686973ff6973ff61ff74657374"}` + got, err := json.Marshal(input) + if err != nil { + t.Fatal(err) + } + if string(got) != want { + t.Fatalf("json.Marshal(%v): wanted %q got %q", input, want, string(got)) + } +} + +func TestMarshalHexMap(t *testing.T) { + input := map[string]HexString{"Value": HexString("this\xffis\xffa\xfftest")} + want := `{"Value":"74686973ff6973ff61ff74657374"}` + got, err := json.Marshal(input) + if err != nil { + t.Fatal(err) + } + if string(got) != want { + t.Fatalf("json.Marshal(%v): wanted %q got %q", input, want, string(got)) + } +} + +func TestMarshalHexSlice(t *testing.T) { + input := []HexString{HexString("this\xffis\xffa\xfftest")} + want := `["74686973ff6973ff61ff74657374"]` + got, err := json.Marshal(input) + if err != nil { + t.Fatal(err) + } + if string(got) != want { + t.Fatalf("json.Marshal(%v): wanted %q got %q", input, want, string(got)) + } +}