Add gitea/forgejo sepcific queries support
This commit is contained in:
parent
18acf05b6b
commit
5b50734599
|
@ -8,23 +8,33 @@ import (
|
||||||
"github.com/ps78674/goldap/message"
|
"github.com/ps78674/goldap/message"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *server) buildBleveQuery(f message.Filter) (string, error) {
|
func (s *server) buildBleveQuery(f message.Filter) (string, bool, error) {
|
||||||
s.l.Trace("Building search expression", "type", fmt.Sprintf("%T", f), "filter", fmt.Sprintf("%#v", f))
|
s.l.Trace("Building search expression", "type", fmt.Sprintf("%T", f), "filter", fmt.Sprintf("%#v", f))
|
||||||
var err error
|
var err error
|
||||||
var etmp string
|
var etmp string
|
||||||
var expr []string
|
var expr []string
|
||||||
|
var exprAnd bool = false
|
||||||
switch f := f.(type) {
|
switch f := f.(type) {
|
||||||
case message.FilterEqualityMatch:
|
case message.FilterEqualityMatch:
|
||||||
etmp, err = mapToBleveStringQuery(string(f.AttributeDesc()), "=", string(f.AssertionValue()))
|
etmp, err = mapToBleveStringQuery(string(f.AttributeDesc()), "=", string(f.AssertionValue()))
|
||||||
expr = append(expr, etmp)
|
expr = append(expr, etmp)
|
||||||
case message.FilterOr:
|
case message.FilterOr:
|
||||||
for _, subf := range f {
|
for _, subf := range f {
|
||||||
s, err := s.buildBleveQuery(subf)
|
s, _, err := s.buildBleveQuery(subf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", false, err
|
||||||
}
|
}
|
||||||
expr = append(expr, s)
|
expr = append(expr, s)
|
||||||
}
|
}
|
||||||
|
case message.FilterAnd:
|
||||||
|
for _, subf := range f {
|
||||||
|
s, _, err := s.buildBleveQuery(subf)
|
||||||
|
if err != nil {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
expr = append(expr, s)
|
||||||
|
}
|
||||||
|
exprAnd = true
|
||||||
case message.FilterPresent:
|
case message.FilterPresent:
|
||||||
etmp, err = mapToBleveStringQuery(string(f), "=", "*")
|
etmp, err = mapToBleveStringQuery(string(f), "=", "*")
|
||||||
expr = append(expr, etmp)
|
expr = append(expr, etmp)
|
||||||
|
@ -32,7 +42,7 @@ func (s *server) buildBleveQuery(f message.Filter) (string, error) {
|
||||||
s.l.Warn("Unsupported search filter", "filter", fmt.Sprintf("%#v", f))
|
s.l.Warn("Unsupported search filter", "filter", fmt.Sprintf("%#v", f))
|
||||||
err = errors.New("unsupported search filter")
|
err = errors.New("unsupported search filter")
|
||||||
}
|
}
|
||||||
return strings.Join(expr, " "), err
|
return strings.Join(expr, " "), exprAnd, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func mapToBleveStringQuery(attr, op, val string) (string, error) {
|
func mapToBleveStringQuery(attr, op, val string) (string, error) {
|
||||||
|
@ -43,6 +53,10 @@ func mapToBleveStringQuery(attr, op, val string) (string, error) {
|
||||||
predicate = "ID"
|
predicate = "ID"
|
||||||
case "cn":
|
case "cn":
|
||||||
predicate = "Name"
|
predicate = "Name"
|
||||||
|
case "member":
|
||||||
|
predicate = ""
|
||||||
|
op = ""
|
||||||
|
return val, nil
|
||||||
default:
|
default:
|
||||||
return "", errors.New("search attribute is unsupported")
|
return "", errors.New("search attribute is unsupported")
|
||||||
}
|
}
|
||||||
|
|
|
@ -103,4 +103,9 @@ func (s *server) SetDomain(domain string) {
|
||||||
Scope(ldap.SearchRequestHomeSubtree).
|
Scope(ldap.SearchRequestHomeSubtree).
|
||||||
Label("Search - Entities")
|
Label("Search - Entities")
|
||||||
|
|
||||||
|
s.routes.Search(s.handleSearchGroups).
|
||||||
|
BaseDn(groupSearchDN).
|
||||||
|
Scope(ldap.SearchRequestHomeSubtree).
|
||||||
|
Filter("(&(members=*)(cn=gitea_users))").
|
||||||
|
Label("Search - Group Members")
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,7 +33,7 @@ func (s *server) handleSearchEntities(w ldap.ResponseWriter, m *ldap.Message) {
|
||||||
s.l.Debug("Search Entities")
|
s.l.Debug("Search Entities")
|
||||||
|
|
||||||
r := m.GetSearchRequest()
|
r := m.GetSearchRequest()
|
||||||
expr, err := s.buildBleveQuery(r.Filter())
|
expr, _, err := s.buildBleveQuery(r.Filter())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If err is non-nil at this point it must mean that
|
// If err is non-nil at this point it must mean that
|
||||||
// the above match didn't find a supported filter.
|
// the above match didn't find a supported filter.
|
||||||
|
@ -77,10 +77,19 @@ func (s *server) handleSearchEntities(w ldap.ResponseWriter, m *ldap.Message) {
|
||||||
// plumbed down to this level to permit attribute filtering in the
|
// plumbed down to this level to permit attribute filtering in the
|
||||||
// future.
|
// future.
|
||||||
func (s *server) entitySearchResult(ctx context.Context, e *pb.Entity, dn message.LDAPDN, attrs message.AttributeSelection) (message.SearchResultEntry, error) {
|
func (s *server) entitySearchResult(ctx context.Context, e *pb.Entity, dn message.LDAPDN, attrs message.AttributeSelection) (message.SearchResultEntry, error) {
|
||||||
|
s.l.Debug("Attr selection: ", "attrs", attrs, len(attrs))
|
||||||
res := ldap.NewSearchResultEntry("uid=" + e.GetID() + "," + string(dn))
|
res := ldap.NewSearchResultEntry("uid=" + e.GetID() + "," + string(dn))
|
||||||
res.AddAttribute("uid", message.AttributeValue(e.GetID()))
|
res.AddAttribute("uid", message.AttributeValue(e.GetID()))
|
||||||
res.AddAttribute("uidNumber", message.AttributeValue(strconv.Itoa(int(e.GetNumber()))))
|
res.AddAttribute("uidNumber", message.AttributeValue(strconv.Itoa(int(e.GetNumber()))))
|
||||||
|
|
||||||
|
mail, err := s.c.EntityKVGet(ctx, e.GetID(), "mail")
|
||||||
|
//if err != nil {
|
||||||
|
// return res, err
|
||||||
|
//}
|
||||||
|
if len(mail["mail"]) > 0 {
|
||||||
|
res.AddAttribute("mail", message.AttributeValue(mail["mail"][0]))
|
||||||
|
}
|
||||||
|
|
||||||
grps, err := s.c.EntityGroups(ctx, e.GetID())
|
grps, err := s.c.EntityGroups(ctx, e.GetID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return res, err
|
return res, err
|
||||||
|
@ -93,6 +102,12 @@ func (s *server) entitySearchResult(ctx context.Context, e *pb.Entity, dn messag
|
||||||
}
|
}
|
||||||
res.AddAttribute("memberOf", memberOf...)
|
res.AddAttribute("memberOf", memberOf...)
|
||||||
|
|
||||||
|
entitySearchDN := "uid=" + e.GetID() + ",ou=entities," + strings.Join(s.nc, ",")
|
||||||
|
s.routes.Search(s.handleSearchEntities).
|
||||||
|
BaseDn(entitySearchDN).
|
||||||
|
Scope(ldap.SearchRequestHomeSubtree).
|
||||||
|
Label("Search - Entities (" + e.GetID() + ")")
|
||||||
|
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -102,7 +117,7 @@ func (s *server) handleSearchGroups(w ldap.ResponseWriter, m *ldap.Message) {
|
||||||
|
|
||||||
r := m.GetSearchRequest()
|
r := m.GetSearchRequest()
|
||||||
|
|
||||||
expr, err := s.buildBleveQuery(r.Filter())
|
expr, exprAnd, err := s.buildBleveQuery(r.Filter())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If err is non-nil at this point it must mean that
|
// If err is non-nil at this point it must mean that
|
||||||
// the above match didn't find a supported filter.
|
// the above match didn't find a supported filter.
|
||||||
|
@ -115,6 +130,25 @@ func (s *server) handleSearchGroups(w ldap.ResponseWriter, m *ldap.Message) {
|
||||||
|
|
||||||
s.l.Debug("Searching groups", "expr", expr)
|
s.l.Debug("Searching groups", "expr", expr)
|
||||||
|
|
||||||
|
var matchUid string
|
||||||
|
if exprAnd {
|
||||||
|
s.l.Debug("Filtering by AND...")
|
||||||
|
exprSlice := strings.Split(expr, " ")
|
||||||
|
for i := range exprSlice {
|
||||||
|
if !strings.Contains(exprSlice[i], "Name") && !strings.Contains(exprSlice[i], "ID") {
|
||||||
|
matchUid = exprSlice[i]
|
||||||
|
if (i+1) == len(exprSlice){
|
||||||
|
exprSlice = exprSlice[:i]
|
||||||
|
} else {
|
||||||
|
exprSlice = append(exprSlice[:i], exprSlice[i+1:]...)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
expr = strings.Join(exprSlice, " ")
|
||||||
|
s.l.Debug("Filtered by AND", "expr", expr, "matchUid", matchUid)
|
||||||
|
}
|
||||||
|
|
||||||
groups, err := s.c.GroupSearch(ctx, expr)
|
groups, err := s.c.GroupSearch(ctx, expr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res := ldap.NewSearchResultDoneResponse(ldap.LDAPResultOperationsError)
|
res := ldap.NewSearchResultDoneResponse(ldap.LDAPResultOperationsError)
|
||||||
|
@ -125,7 +159,20 @@ func (s *server) handleSearchGroups(w ldap.ResponseWriter, m *ldap.Message) {
|
||||||
|
|
||||||
for i := range groups {
|
for i := range groups {
|
||||||
s.l.Debug("Found group", "group", groups[i].GetName())
|
s.l.Debug("Found group", "group", groups[i].GetName())
|
||||||
e, err := s.groupSearchResult(ctx, groups[i], r.BaseObject(), r.Attributes())
|
var e message.SearchResultEntry
|
||||||
|
var complete bool
|
||||||
|
var err error
|
||||||
|
if exprAnd {
|
||||||
|
s.l.Debug("Match UID: ", matchUid)
|
||||||
|
e, complete, err = s.groupSearchResultWithMatch(ctx, groups[i], r.BaseObject(), r.Attributes(), matchUid)
|
||||||
|
if !complete {
|
||||||
|
res := ldap.NewSearchResultDoneResponse(ldap.LDAPResultSuccess)
|
||||||
|
w.Write(res)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
e, err = s.groupSearchResult(ctx, groups[i], r.BaseObject(), r.Attributes())
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res := ldap.NewSearchResultDoneResponse(ldap.LDAPResultOperationsError)
|
res := ldap.NewSearchResultDoneResponse(ldap.LDAPResultOperationsError)
|
||||||
res.SetDiagnosticMessage(err.Error())
|
res.SetDiagnosticMessage(err.Error())
|
||||||
|
@ -162,3 +209,27 @@ func (s *server) groupSearchResult(ctx context.Context, g *pb.Group, dn message.
|
||||||
|
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *server) groupSearchResultWithMatch(ctx context.Context, g *pb.Group, dn message.LDAPDN, attrs message.AttributeSelection, matchUid string) (message.SearchResultEntry, bool, error) {
|
||||||
|
res := ldap.NewSearchResultEntry("cn=" + g.GetName() + "," + string(dn))
|
||||||
|
res.AddAttribute("cn", message.AttributeValue(g.GetName()))
|
||||||
|
res.AddAttribute("gidNumber", message.AttributeValue(strconv.Itoa(int(g.GetNumber()))))
|
||||||
|
|
||||||
|
members, err := s.c.GroupMembers(ctx, g.GetName())
|
||||||
|
if err != nil {
|
||||||
|
return res, true, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var complete bool = false
|
||||||
|
memberList := []message.AttributeValue{}
|
||||||
|
for i := range members {
|
||||||
|
if members[i].GetID() == matchUid {
|
||||||
|
g := "uid=" + members[i].GetID() + ",ou=entities," + strings.Join(s.nc, ",")
|
||||||
|
memberList = append(memberList, message.AttributeValue(g))
|
||||||
|
complete = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
res.AddAttribute("member", memberList...)
|
||||||
|
|
||||||
|
return res, complete, nil
|
||||||
|
}
|
||||||
|
|
|
@ -13,6 +13,7 @@ type naClient interface {
|
||||||
AuthEntity(context.Context, string, string) error
|
AuthEntity(context.Context, string, string) error
|
||||||
EntitySearch(context.Context, string) ([]*pb.Entity, error)
|
EntitySearch(context.Context, string) ([]*pb.Entity, error)
|
||||||
EntityGroups(context.Context, string) ([]*pb.Group, error)
|
EntityGroups(context.Context, string) ([]*pb.Group, error)
|
||||||
|
EntityKVGet(context.Context, string, string) (map[string][]string, error)
|
||||||
|
|
||||||
GroupSearch(context.Context, string) ([]*pb.Group, error)
|
GroupSearch(context.Context, string) ([]*pb.Group, error)
|
||||||
GroupMembers(context.Context, string) ([]*pb.Entity, error)
|
GroupMembers(context.Context, string) ([]*pb.Entity, error)
|
||||||
|
|
Loading…
Reference in a new issue