Add memberOf filtering for user search

This commit is contained in:
signaryk 2024-07-05 16:05:57 -05:00
parent 5b50734599
commit c2eca7ed43
2 changed files with 76 additions and 6 deletions

View file

@ -54,8 +54,8 @@ func mapToBleveStringQuery(attr, op, val string) (string, error) {
case "cn": case "cn":
predicate = "Name" predicate = "Name"
case "member": case "member":
predicate = "" return val, nil
op = "" case "memberOf":
return val, nil return val, nil
default: default:
return "", errors.New("search attribute is unsupported") return "", errors.New("search attribute is unsupported")

View file

@ -33,7 +33,7 @@ func (s *server) handleSearchEntities(w ldap.ResponseWriter, m *ldap.Message) {
s.l.Debug("Search Entities") s.l.Debug("Search Entities")
r := m.GetSearchRequest() r := m.GetSearchRequest()
expr, _, err := s.buildBleveQuery(r.Filter()) expr, exprAnd, err := s.buildBleveQuery(r.Filter())
if err != nil { if err != nil {
// If err is non-nil at this point it must mean that // If err is non-nil at this point it must mean that
// the above match didn't find a supported filter. // the above match didn't find a supported filter.
@ -46,6 +46,25 @@ func (s *server) handleSearchEntities(w ldap.ResponseWriter, m *ldap.Message) {
s.l.Debug("Searching entities", "query", expr) s.l.Debug("Searching entities", "query", expr)
var matchGroup string
if exprAnd {
s.l.Debug("Filtering by AND...")
exprSlice := strings.Split(expr, " ")
for i := range exprSlice {
if !strings.Contains(exprSlice[i], "Name") && !strings.Contains(exprSlice[i], "ID") {
matchGroup = exprSlice[i]
if (i + 1) == len(exprSlice) {
exprSlice = exprSlice[:i]
} else {
exprSlice = append(exprSlice[:i], exprSlice[i+1:]...)
}
break
}
}
expr = strings.Join(exprSlice, " ")
s.l.Debug("Filtered by AND", "expr", expr, "matchUid", matchGroup)
}
ents, err := s.c.EntitySearch(ctx, expr) ents, err := s.c.EntitySearch(ctx, expr)
if err != nil { if err != nil {
res := ldap.NewSearchResultDoneResponse(ldap.LDAPResultOperationsError) res := ldap.NewSearchResultDoneResponse(ldap.LDAPResultOperationsError)
@ -55,7 +74,19 @@ func (s *server) handleSearchEntities(w ldap.ResponseWriter, m *ldap.Message) {
} }
for i := range ents { for i := range ents {
e, err := s.entitySearchResult(ctx, ents[i], r.BaseObject(), r.Attributes()) var e message.SearchResultEntry
var complete bool
var err error
if exprAnd {
e, complete, err = s.entitySearchResultWithMatch(ctx, ents[i], r.BaseObject(), r.Attributes(), matchGroup)
if !complete {
res := ldap.NewSearchResultDoneResponse(ldap.LDAPResultSuccess)
w.Write(res)
return
}
} else {
e, err = s.entitySearchResult(ctx, ents[i], r.BaseObject(), r.Attributes())
}
if err != nil { if err != nil {
res := ldap.NewSearchResultDoneResponse(ldap.LDAPResultOperationsError) res := ldap.NewSearchResultDoneResponse(ldap.LDAPResultOperationsError)
res.SetDiagnosticMessage(err.Error()) res.SetDiagnosticMessage(err.Error())
@ -102,7 +133,7 @@ func (s *server) entitySearchResult(ctx context.Context, e *pb.Entity, dn messag
} }
res.AddAttribute("memberOf", memberOf...) res.AddAttribute("memberOf", memberOf...)
entitySearchDN := "uid=" + e.GetID() + ",ou=entities," + strings.Join(s.nc, ",") entitySearchDN := "uid=" + e.GetID() + ",ou=entities," + strings.Join(s.nc, ",")
s.routes.Search(s.handleSearchEntities). s.routes.Search(s.handleSearchEntities).
BaseDn(entitySearchDN). BaseDn(entitySearchDN).
Scope(ldap.SearchRequestHomeSubtree). Scope(ldap.SearchRequestHomeSubtree).
@ -111,6 +142,45 @@ func (s *server) entitySearchResult(ctx context.Context, e *pb.Entity, dn messag
return res, nil return res, nil
} }
func (s *server) entitySearchResultWithMatch(ctx context.Context, e *pb.Entity, dn message.LDAPDN, attrs message.AttributeSelection, matchGroup string) (message.SearchResultEntry, bool, error) {
s.l.Debug("Attr selection: ", "attrs", attrs, len(attrs))
res := ldap.NewSearchResultEntry("uid=" + e.GetID() + "," + string(dn))
res.AddAttribute("uid", message.AttributeValue(e.GetID()))
res.AddAttribute("uidNumber", message.AttributeValue(strconv.Itoa(int(e.GetNumber()))))
res.AddAttribute("displayName", message.AttributeValue(e.GetMeta().GetDisplayName()))
mail, err := s.c.EntityKVGet(ctx, e.GetID(), "mail")
if len(mail["mail"]) > 0 {
res.AddAttribute("mail", message.AttributeValue(mail["mail"][0]))
}
grps, err := s.c.EntityGroups(ctx, e.GetID())
if err != nil {
return res, true, err
}
memberOf := []message.AttributeValue{}
for i := range grps {
if grps[i].GetName() == matchGroup {
g := "cn=" + grps[i].GetName() + ",ou=groups," + strings.Join(s.nc, ",")
memberOf = append(memberOf, message.AttributeValue(g))
}
}
if len(memberOf) > 0 {
res.AddAttribute("memberOf", memberOf...)
} else {
return res, false, nil
}
entitySearchDN := "uid=" + e.GetID() + ",ou=entities," + strings.Join(s.nc, ",")
s.routes.Search(s.handleSearchEntities).
BaseDn(entitySearchDN).
Scope(ldap.SearchRequestHomeSubtree).
Label("Search - Entities (" + e.GetID() + ")")
return res, true, nil
}
func (s *server) handleSearchGroups(w ldap.ResponseWriter, m *ldap.Message) { func (s *server) handleSearchGroups(w ldap.ResponseWriter, m *ldap.Message) {
ctx := context.Background() ctx := context.Background()
s.l.Debug("Search Groups") s.l.Debug("Search Groups")
@ -137,7 +207,7 @@ func (s *server) handleSearchGroups(w ldap.ResponseWriter, m *ldap.Message) {
for i := range exprSlice { for i := range exprSlice {
if !strings.Contains(exprSlice[i], "Name") && !strings.Contains(exprSlice[i], "ID") { if !strings.Contains(exprSlice[i], "Name") && !strings.Contains(exprSlice[i], "ID") {
matchUid = exprSlice[i] matchUid = exprSlice[i]
if (i+1) == len(exprSlice){ if (i + 1) == len(exprSlice) {
exprSlice = exprSlice[:i] exprSlice = exprSlice[:i]
} else { } else {
exprSlice = append(exprSlice[:i], exprSlice[i+1:]...) exprSlice = append(exprSlice[:i], exprSlice[i+1:]...)