Merge branch 'master' into kegan/http-auth

This commit is contained in:
Kegan Dougal 2017-02-24 09:31:56 +00:00
commit cd572a4c85
14 changed files with 567 additions and 35 deletions

View file

@ -1,6 +1,10 @@
// Package api provides the types that are used to communicate with the roomserver.
package api
import (
"encoding/json"
)
const (
// KindOutlier event fall outside the contiguous event graph.
// We do not have the state for these events.
@ -36,7 +40,61 @@ type InputRoomEvent struct {
// For example many matrix events forget to reference the m.room.create event even though it is needed for auth.
// (since synapse allows this to happen we have to allow it as well.)
AuthEventIDs []string
// Whether the state is supplied as a list of event IDs or whether it
// should be derived from the state at the previous events.
HasState bool
// Optional list of state event IDs forming the state before this event.
// These state events must have already been persisted.
// These are only used if HasState is true.
// The list can be empty, for example when storing the first event in a room.
StateEventIDs []string
}
// UnmarshalJSON implements json.Unmarshaller
func (ire *InputRoomEvent) UnmarshalJSON(data []byte) error {
// Create a struct rather than unmarshalling directly into the InputRoomEvent
// so that we can use json.RawMessage.
// We use json.RawMessage so that the event JSON is sent as JSON rather than
// being base64 encoded which is the default for []byte.
var content struct {
Kind int
Event *json.RawMessage
AuthEventIDs []string
StateEventIDs []string
HasState bool
}
if err := json.Unmarshal(data, &content); err != nil {
return err
}
ire.Kind = content.Kind
ire.AuthEventIDs = content.AuthEventIDs
ire.StateEventIDs = content.StateEventIDs
ire.HasState = content.HasState
if content.Event != nil {
ire.Event = []byte(*content.Event)
}
return nil
}
// MarshalJSON implements json.Marshaller
func (ire InputRoomEvent) MarshalJSON() ([]byte, error) {
// Create a struct rather than marshalling directly from the InputRoomEvent
// so that we can use json.RawMessage.
// We use json.RawMessage so that the event JSON is sent as JSON rather than
// being base64 encoded which is the default for []byte.
event := json.RawMessage(ire.Event)
content := struct {
Kind int
Event *json.RawMessage
AuthEventIDs []string
StateEventIDs []string
HasState bool
}{
Kind: ire.Kind,
AuthEventIDs: ire.AuthEventIDs,
StateEventIDs: ire.StateEventIDs,
Event: &event,
HasState: ire.HasState,
}
return json.Marshal(&content)
}

View file

@ -69,7 +69,7 @@ func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error {
if stateAtEvent.BeforeStateSnapshotNID == 0 {
// We haven't calculated a state for this event yet.
// Lets calculate one.
if input.StateEventIDs != nil {
if input.HasState {
// We've been told what the state at the event is so we don't need to calculate it.
// Check that those state events are in the database and store the state.
entries, err := db.StateEntriesForEventIDs(input.StateEventIDs)
@ -89,6 +89,11 @@ func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error {
db.SetState(stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
}
if input.Kind == api.KindBackfill {
// Backfill is not implemented.
panic("Not implemented")
}
// Update the extremities of the event graph for the room
if err := updateLatestEvents(db, roomNID, stateAtEvent, event); err != nil {
return err
@ -102,5 +107,5 @@ func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error {
// - The event itself
// - The visiblity of the event, i.e. who is allowed to see the event.
// - The changes to the current state of the room.
panic("Not implemented")
return nil
}

View file

@ -58,15 +58,23 @@ func doUpdateLatestEvents(
}
// Check if this event references any of the latest events in the room.
var alreadyInLatest bool
var newLatest []types.StateAtEventAndReference
for _, l := range oldLatest {
keep := true
for _, prevEvent := range prevEvents {
if l.EventID == prevEvent.EventID && bytes.Compare(l.EventSHA256, prevEvent.EventSHA256) == 0 {
// This event can be removed from the latest events cause we've found an event that references it.
// (If an event is referenced by another event then it can't be one of the latest events in the room
// because we have an event that comes after it)
continue
keep = false
break
}
}
if l.EventNID == stateAtEvent.EventNID {
alreadyInLatest = true
}
if keep {
// Keep the event in the latest events.
newLatest = append(newLatest, l)
}
@ -79,8 +87,9 @@ func doUpdateLatestEvents(
return err
}
if !alreadyReferenced {
// This event is not referenced by any of the events in the room.
if !alreadyReferenced && !alreadyInLatest {
// This event is not referenced by any of the events in the room
// and the event is not already in the latest events.
// Add it to the latest events
newLatest = append(newLatest, types.StateAtEventAndReference{
StateAtEvent: stateAtEvent,

View file

@ -31,8 +31,7 @@ INSERT INTO event_state_keys (event_state_key_nid, event_state_key) VALUES
const insertEventStateKeyNIDSQL = "" +
"INSERT INTO event_state_keys (event_state_key) VALUES ($1)" +
" ON CONFLICT ON CONSTRAINT event_state_key_unique" +
" DO UPDATE SET event_state_key = $1" +
" RETURNING (event_state_key_nid)"
" DO NOTHING RETURNING (event_state_key_nid)"
const selectEventStateKeyNIDSQL = "" +
"SELECT event_state_key_nid FROM event_state_keys WHERE event_state_key = $1"

View file

@ -50,14 +50,18 @@ INSERT INTO event_types (event_type_nid, event_type) VALUES
// In that case the ID will be assigned using the next value from the sequence.
// We use `RETURNING` to tell postgres to return the assigned ID.
// But it's possible that the type was added in a query that raced with us.
// This will result in a conflict on the event_type_unique constraint.
// We peform a update that does nothing rather that doing nothing at all because
// postgres won't return anything unless we touch a row in the table.
// This will result in a conflict on the event_type_unique constraint, in this
// case we do nothing. Postgresql won't return a row in that case so we rely on
// the caller catching the sql.ErrNoRows error and running a select to get the row.
// We could get postgresql to return the row on a conflict by updating the row
// but it doesn't seem like a good idea to modify the rows just to make postgresql
// return it. Modifying the rows will cause postgres to assign a new tuple for the
// row even though the data doesn't change resulting in unncesssary modifications
// to the indexes.
const insertEventTypeNIDSQL = "" +
"INSERT INTO event_types (event_type) VALUES ($1)" +
" ON CONFLICT ON CONSTRAINT event_type_unique" +
" DO UPDATE SET event_type = $1" +
" RETURNING (event_type_nid)"
" DO NOTHING RETURNING (event_type_nid)"
const selectEventTypeNIDSQL = "" +
"SELECT event_type_nid FROM event_types WHERE event_type = $1"

View file

@ -47,9 +47,12 @@ const insertEventSQL = "" +
"INSERT INTO events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids)" +
" VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT ON CONSTRAINT event_id_unique" +
" DO UPDATE SET event_id = $1" +
" DO NOTHING" +
" RETURNING event_nid, state_snapshot_nid"
const selectEventSQL = "" +
"SELECT event_nid, state_snapshot_nid FROM events WHERE event_id = $1"
// Bulk lookup of events by string ID.
// Sort by the numeric IDs for event type and state key.
// This means we can use binary search to lookup entries by type and state key.
@ -71,6 +74,7 @@ const bulkSelectStateAtEventAndReferenceSQL = "" +
type eventStatements struct {
insertEventStmt *sql.Stmt
selectEventStmt *sql.Stmt
bulkSelectStateEventByIDStmt *sql.Stmt
bulkSelectStateAtEventByIDStmt *sql.Stmt
updateEventStateStmt *sql.Stmt
@ -85,6 +89,9 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
return
}
if s.selectEventStmt, err = db.Prepare(selectEventSQL); err != nil {
return
}
if s.bulkSelectStateEventByIDStmt, err = db.Prepare(bulkSelectStateEventByIDSQL); err != nil {
return
}
@ -119,6 +126,13 @@ func (s *eventStatements) insertEvent(
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
}
func (s *eventStatements) selectEvent(eventID string) (types.EventNID, types.StateSnapshotNID, error) {
var eventNID int64
var stateNID int64
err := s.selectEventStmt.QueryRow(eventID).Scan(&eventNID, &stateNID)
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
}
func (s *eventStatements) bulkSelectStateEventByID(eventIDs []string) ([]types.StateEntry, error) {
rows, err := s.bulkSelectStateEventByIDStmt.Query(pq.StringArray(eventIDs))
if err != nil {
@ -134,9 +148,9 @@ func (s *eventStatements) bulkSelectStateEventByID(eventIDs []string) ([]types.S
for ; rows.Next(); i++ {
result := &results[i]
if err = rows.Scan(
&result.EventNID,
&result.EventTypeNID,
&result.EventStateKeyNID,
&result.EventNID,
); err != nil {
return nil, err
}
@ -163,9 +177,9 @@ func (s *eventStatements) bulkSelectStateAtEventByID(eventIDs []string) ([]types
for ; rows.Next(); i++ {
result := &results[i]
if err = rows.Scan(
&result.EventNID,
&result.EventTypeNID,
&result.EventStateKeyNID,
&result.EventNID,
&result.BeforeStateSnapshotNID,
); err != nil {
return nil, err

View file

@ -16,7 +16,7 @@ CREATE TABLE IF NOT EXISTS rooms (
-- The most recent events in the room that aren't referenced by another event.
-- This list may empty if the server hasn't joined the room yet.
-- (The server will be in that state while it stores the events for the initial state of the room)
latest_event_nids BIGINT[] NOT NULL
latest_event_nids BIGINT[] NOT NULL DEFAULT '{}'::BIGINT[]
);
`
@ -24,8 +24,7 @@ CREATE TABLE IF NOT EXISTS rooms (
const insertRoomNIDSQL = "" +
"INSERT INTO rooms (room_id) VALUES ($1)" +
" ON CONFLICT ON CONSTRAINT room_id_unique" +
" DO UPDATE SET room_id = $1" +
" RETURNING (room_nid)"
" DO NOTHING RETURNING (room_nid)"
const selectRoomNIDSQL = "" +
"SELECT room_nid FROM rooms WHERE room_id = $1"

View file

@ -73,7 +73,13 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ
event.EventReference().EventSHA256,
authEventNIDs,
); err != nil {
return 0, types.StateAtEvent{}, err
if err == sql.ErrNoRows {
// We've already inserted the event so select the numeric event ID
eventNID, stateNID, err = d.statements.selectEvent(event.EventID())
}
if err != nil {
return 0, types.StateAtEvent{}, err
}
}
if err = d.statements.insertEventJSON(eventNID, event.JSON()); err != nil {
@ -97,12 +103,13 @@ func (d *Database) assignRoomNID(roomID string) (types.RoomNID, error) {
roomNID, err := d.statements.selectRoomNID(roomID)
if err == sql.ErrNoRows {
// We don't have a numeric ID so insert one into the database.
return d.statements.insertRoomNID(roomID)
roomNID, err = d.statements.insertRoomNID(roomID)
if err == sql.ErrNoRows {
// We raced with another insert so run the select again.
roomNID, err = d.statements.selectRoomNID(roomID)
}
}
if err != nil {
return 0, err
}
return roomNID, nil
return roomNID, err
}
func (d *Database) assignEventTypeNID(eventType string) (types.EventTypeNID, error) {
@ -110,12 +117,13 @@ func (d *Database) assignEventTypeNID(eventType string) (types.EventTypeNID, err
eventTypeNID, err := d.statements.selectEventTypeNID(eventType)
if err == sql.ErrNoRows {
// We don't have a numeric ID so insert one into the database.
return d.statements.insertEventTypeNID(eventType)
eventTypeNID, err = d.statements.insertEventTypeNID(eventType)
if err == sql.ErrNoRows {
// We raced with another insert so run the select again.
eventTypeNID, err = d.statements.selectEventTypeNID(eventType)
}
}
if err != nil {
return 0, err
}
return eventTypeNID, nil
return eventTypeNID, err
}
func (d *Database) assignStateKeyNID(eventStateKey string) (types.EventStateKeyNID, error) {
@ -123,12 +131,13 @@ func (d *Database) assignStateKeyNID(eventStateKey string) (types.EventStateKeyN
eventStateKeyNID, err := d.statements.selectEventStateKeyNID(eventStateKey)
if err == sql.ErrNoRows {
// We don't have a numeric ID so insert one into the database.
return d.statements.insertEventStateKeyNID(eventStateKey)
eventStateKeyNID, err = d.statements.insertEventStateKeyNID(eventStateKey)
if err == sql.ErrNoRows {
// We raced with another insert so run the select again.
eventStateKeyNID, err = d.statements.selectEventStateKeyNID(eventStateKey)
}
}
if err != nil {
return 0, err
}
return eventStateKeyNID, nil
return eventStateKeyNID, err
}
// StateEntriesForEventIDs implements input.EventDatabase

8
vendor/manifest vendored
View file

@ -59,6 +59,12 @@
"revision": "7db9049039a047d955fe8c19b83c8ff5abd765c7",
"branch": "master"
},
{
"importpath": "github.com/gorilla/context",
"repository": "https://github.com/gorilla/context",
"revision": "08b5f424b9271eedf6f9f0ce86cb9396ed337a42",
"branch": "master"
},
{
"importpath": "github.com/gorilla/mux",
"repository": "https://github.com/gorilla/mux",
@ -200,4 +206,4 @@
"branch": "master"
}
]
}
}

View file

@ -0,0 +1,27 @@
Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View file

@ -0,0 +1,10 @@
context
=======
[![Build Status](https://travis-ci.org/gorilla/context.png?branch=master)](https://travis-ci.org/gorilla/context)
gorilla/context is a general purpose registry for global request variables.
> Note: gorilla/context, having been born well before `context.Context` existed, does not play well
> with the shallow copying of the request that [`http.Request.WithContext`](https://golang.org/pkg/net/http/#Request.WithContext) (added to net/http Go 1.7 onwards) performs. You should either use *just* gorilla/context, or moving forward, the new `http.Request.Context()`.
Read the full documentation here: http://www.gorillatoolkit.org/pkg/context

View file

@ -0,0 +1,143 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package context
import (
"net/http"
"sync"
"time"
)
var (
mutex sync.RWMutex
data = make(map[*http.Request]map[interface{}]interface{})
datat = make(map[*http.Request]int64)
)
// Set stores a value for a given key in a given request.
func Set(r *http.Request, key, val interface{}) {
mutex.Lock()
if data[r] == nil {
data[r] = make(map[interface{}]interface{})
datat[r] = time.Now().Unix()
}
data[r][key] = val
mutex.Unlock()
}
// Get returns a value stored for a given key in a given request.
func Get(r *http.Request, key interface{}) interface{} {
mutex.RLock()
if ctx := data[r]; ctx != nil {
value := ctx[key]
mutex.RUnlock()
return value
}
mutex.RUnlock()
return nil
}
// GetOk returns stored value and presence state like multi-value return of map access.
func GetOk(r *http.Request, key interface{}) (interface{}, bool) {
mutex.RLock()
if _, ok := data[r]; ok {
value, ok := data[r][key]
mutex.RUnlock()
return value, ok
}
mutex.RUnlock()
return nil, false
}
// GetAll returns all stored values for the request as a map. Nil is returned for invalid requests.
func GetAll(r *http.Request) map[interface{}]interface{} {
mutex.RLock()
if context, ok := data[r]; ok {
result := make(map[interface{}]interface{}, len(context))
for k, v := range context {
result[k] = v
}
mutex.RUnlock()
return result
}
mutex.RUnlock()
return nil
}
// GetAllOk returns all stored values for the request as a map and a boolean value that indicates if
// the request was registered.
func GetAllOk(r *http.Request) (map[interface{}]interface{}, bool) {
mutex.RLock()
context, ok := data[r]
result := make(map[interface{}]interface{}, len(context))
for k, v := range context {
result[k] = v
}
mutex.RUnlock()
return result, ok
}
// Delete removes a value stored for a given key in a given request.
func Delete(r *http.Request, key interface{}) {
mutex.Lock()
if data[r] != nil {
delete(data[r], key)
}
mutex.Unlock()
}
// Clear removes all values stored for a given request.
//
// This is usually called by a handler wrapper to clean up request
// variables at the end of a request lifetime. See ClearHandler().
func Clear(r *http.Request) {
mutex.Lock()
clear(r)
mutex.Unlock()
}
// clear is Clear without the lock.
func clear(r *http.Request) {
delete(data, r)
delete(datat, r)
}
// Purge removes request data stored for longer than maxAge, in seconds.
// It returns the amount of requests removed.
//
// If maxAge <= 0, all request data is removed.
//
// This is only used for sanity check: in case context cleaning was not
// properly set some request data can be kept forever, consuming an increasing
// amount of memory. In case this is detected, Purge() must be called
// periodically until the problem is fixed.
func Purge(maxAge int) int {
mutex.Lock()
count := 0
if maxAge <= 0 {
count = len(data)
data = make(map[*http.Request]map[interface{}]interface{})
datat = make(map[*http.Request]int64)
} else {
min := time.Now().Unix() - int64(maxAge)
for r := range data {
if datat[r] < min {
clear(r)
count++
}
}
}
mutex.Unlock()
return count
}
// ClearHandler wraps an http.Handler and clears request values at the end
// of a request lifetime.
func ClearHandler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer Clear(r)
h.ServeHTTP(w, r)
})
}

View file

@ -0,0 +1,161 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package context
import (
"net/http"
"testing"
)
type keyType int
const (
key1 keyType = iota
key2
)
func TestContext(t *testing.T) {
assertEqual := func(val interface{}, exp interface{}) {
if val != exp {
t.Errorf("Expected %v, got %v.", exp, val)
}
}
r, _ := http.NewRequest("GET", "http://localhost:8080/", nil)
emptyR, _ := http.NewRequest("GET", "http://localhost:8080/", nil)
// Get()
assertEqual(Get(r, key1), nil)
// Set()
Set(r, key1, "1")
assertEqual(Get(r, key1), "1")
assertEqual(len(data[r]), 1)
Set(r, key2, "2")
assertEqual(Get(r, key2), "2")
assertEqual(len(data[r]), 2)
//GetOk
value, ok := GetOk(r, key1)
assertEqual(value, "1")
assertEqual(ok, true)
value, ok = GetOk(r, "not exists")
assertEqual(value, nil)
assertEqual(ok, false)
Set(r, "nil value", nil)
value, ok = GetOk(r, "nil value")
assertEqual(value, nil)
assertEqual(ok, true)
// GetAll()
values := GetAll(r)
assertEqual(len(values), 3)
// GetAll() for empty request
values = GetAll(emptyR)
if values != nil {
t.Error("GetAll didn't return nil value for invalid request")
}
// GetAllOk()
values, ok = GetAllOk(r)
assertEqual(len(values), 3)
assertEqual(ok, true)
// GetAllOk() for empty request
values, ok = GetAllOk(emptyR)
assertEqual(len(values), 0)
assertEqual(ok, false)
// Delete()
Delete(r, key1)
assertEqual(Get(r, key1), nil)
assertEqual(len(data[r]), 2)
Delete(r, key2)
assertEqual(Get(r, key2), nil)
assertEqual(len(data[r]), 1)
// Clear()
Clear(r)
assertEqual(len(data), 0)
}
func parallelReader(r *http.Request, key string, iterations int, wait, done chan struct{}) {
<-wait
for i := 0; i < iterations; i++ {
Get(r, key)
}
done <- struct{}{}
}
func parallelWriter(r *http.Request, key, value string, iterations int, wait, done chan struct{}) {
<-wait
for i := 0; i < iterations; i++ {
Set(r, key, value)
}
done <- struct{}{}
}
func benchmarkMutex(b *testing.B, numReaders, numWriters, iterations int) {
b.StopTimer()
r, _ := http.NewRequest("GET", "http://localhost:8080/", nil)
done := make(chan struct{})
b.StartTimer()
for i := 0; i < b.N; i++ {
wait := make(chan struct{})
for i := 0; i < numReaders; i++ {
go parallelReader(r, "test", iterations, wait, done)
}
for i := 0; i < numWriters; i++ {
go parallelWriter(r, "test", "123", iterations, wait, done)
}
close(wait)
for i := 0; i < numReaders+numWriters; i++ {
<-done
}
}
}
func BenchmarkMutexSameReadWrite1(b *testing.B) {
benchmarkMutex(b, 1, 1, 32)
}
func BenchmarkMutexSameReadWrite2(b *testing.B) {
benchmarkMutex(b, 2, 2, 32)
}
func BenchmarkMutexSameReadWrite4(b *testing.B) {
benchmarkMutex(b, 4, 4, 32)
}
func BenchmarkMutex1(b *testing.B) {
benchmarkMutex(b, 2, 8, 32)
}
func BenchmarkMutex2(b *testing.B) {
benchmarkMutex(b, 16, 4, 64)
}
func BenchmarkMutex3(b *testing.B) {
benchmarkMutex(b, 1, 2, 128)
}
func BenchmarkMutex4(b *testing.B) {
benchmarkMutex(b, 128, 32, 256)
}
func BenchmarkMutex5(b *testing.B) {
benchmarkMutex(b, 1024, 2048, 64)
}
func BenchmarkMutex6(b *testing.B) {
benchmarkMutex(b, 2048, 1024, 512)
}

View file

@ -0,0 +1,88 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package context stores values shared during a request lifetime.
Note: gorilla/context, having been born well before `context.Context` existed,
does not play well > with the shallow copying of the request that
[`http.Request.WithContext`](https://golang.org/pkg/net/http/#Request.WithContext)
(added to net/http Go 1.7 onwards) performs. You should either use *just*
gorilla/context, or moving forward, the new `http.Request.Context()`.
For example, a router can set variables extracted from the URL and later
application handlers can access those values, or it can be used to store
sessions values to be saved at the end of a request. There are several
others common uses.
The idea was posted by Brad Fitzpatrick to the go-nuts mailing list:
http://groups.google.com/group/golang-nuts/msg/e2d679d303aa5d53
Here's the basic usage: first define the keys that you will need. The key
type is interface{} so a key can be of any type that supports equality.
Here we define a key using a custom int type to avoid name collisions:
package foo
import (
"github.com/gorilla/context"
)
type key int
const MyKey key = 0
Then set a variable. Variables are bound to an http.Request object, so you
need a request instance to set a value:
context.Set(r, MyKey, "bar")
The application can later access the variable using the same key you provided:
func MyHandler(w http.ResponseWriter, r *http.Request) {
// val is "bar".
val := context.Get(r, foo.MyKey)
// returns ("bar", true)
val, ok := context.GetOk(r, foo.MyKey)
// ...
}
And that's all about the basic usage. We discuss some other ideas below.
Any type can be stored in the context. To enforce a given type, make the key
private and wrap Get() and Set() to accept and return values of a specific
type:
type key int
const mykey key = 0
// GetMyKey returns a value for this package from the request values.
func GetMyKey(r *http.Request) SomeType {
if rv := context.Get(r, mykey); rv != nil {
return rv.(SomeType)
}
return nil
}
// SetMyKey sets a value for this package in the request values.
func SetMyKey(r *http.Request, val SomeType) {
context.Set(r, mykey, val)
}
Variables must be cleared at the end of a request, to remove all values
that were stored. This can be done in an http.Handler, after a request was
served. Just call Clear() passing the request:
context.Clear(r)
...or use ClearHandler(), which conveniently wraps an http.Handler to clear
variables at the end of a request lifetime.
The Routers from the packages gorilla/mux and gorilla/pat call Clear()
so if you are using either of them you don't need to clear the context manually.
*/
package context