From c2eca7ed43c5522fdd35d2b10a2f5bd93150d66a Mon Sep 17 00:00:00 2001 From: William Luke Date: Fri, 5 Jul 2024 16:05:57 -0500 Subject: [PATCH] Add memberOf filtering for user search --- internal/ldap/bleve_mapper.go | 4 +- internal/ldap/search.go | 78 +++++++++++++++++++++++++++++++++-- 2 files changed, 76 insertions(+), 6 deletions(-) diff --git a/internal/ldap/bleve_mapper.go b/internal/ldap/bleve_mapper.go index 8421acb..4066ccf 100644 --- a/internal/ldap/bleve_mapper.go +++ b/internal/ldap/bleve_mapper.go @@ -54,8 +54,8 @@ func mapToBleveStringQuery(attr, op, val string) (string, error) { case "cn": predicate = "Name" case "member": - predicate = "" - op = "" + return val, nil + case "memberOf": return val, nil default: return "", errors.New("search attribute is unsupported") diff --git a/internal/ldap/search.go b/internal/ldap/search.go index 41ff024..5613e6b 100644 --- a/internal/ldap/search.go +++ b/internal/ldap/search.go @@ -33,7 +33,7 @@ func (s *server) handleSearchEntities(w ldap.ResponseWriter, m *ldap.Message) { s.l.Debug("Search Entities") r := m.GetSearchRequest() - expr, _, err := s.buildBleveQuery(r.Filter()) + expr, exprAnd, err := s.buildBleveQuery(r.Filter()) if err != nil { // If err is non-nil at this point it must mean that // the above match didn't find a supported filter. @@ -46,6 +46,25 @@ func (s *server) handleSearchEntities(w ldap.ResponseWriter, m *ldap.Message) { s.l.Debug("Searching entities", "query", expr) + var matchGroup string + if exprAnd { + s.l.Debug("Filtering by AND...") + exprSlice := strings.Split(expr, " ") + for i := range exprSlice { + if !strings.Contains(exprSlice[i], "Name") && !strings.Contains(exprSlice[i], "ID") { + matchGroup = exprSlice[i] + if (i + 1) == len(exprSlice) { + exprSlice = exprSlice[:i] + } else { + exprSlice = append(exprSlice[:i], exprSlice[i+1:]...) + } + break + } + } + expr = strings.Join(exprSlice, " ") + s.l.Debug("Filtered by AND", "expr", expr, "matchUid", matchGroup) + } + ents, err := s.c.EntitySearch(ctx, expr) if err != nil { res := ldap.NewSearchResultDoneResponse(ldap.LDAPResultOperationsError) @@ -55,7 +74,19 @@ func (s *server) handleSearchEntities(w ldap.ResponseWriter, m *ldap.Message) { } for i := range ents { - e, err := s.entitySearchResult(ctx, ents[i], r.BaseObject(), r.Attributes()) + var e message.SearchResultEntry + var complete bool + var err error + if exprAnd { + e, complete, err = s.entitySearchResultWithMatch(ctx, ents[i], r.BaseObject(), r.Attributes(), matchGroup) + if !complete { + res := ldap.NewSearchResultDoneResponse(ldap.LDAPResultSuccess) + w.Write(res) + return + } + } else { + e, err = s.entitySearchResult(ctx, ents[i], r.BaseObject(), r.Attributes()) + } if err != nil { res := ldap.NewSearchResultDoneResponse(ldap.LDAPResultOperationsError) res.SetDiagnosticMessage(err.Error()) @@ -102,7 +133,7 @@ func (s *server) entitySearchResult(ctx context.Context, e *pb.Entity, dn messag } res.AddAttribute("memberOf", memberOf...) - entitySearchDN := "uid=" + e.GetID() + ",ou=entities," + strings.Join(s.nc, ",") + entitySearchDN := "uid=" + e.GetID() + ",ou=entities," + strings.Join(s.nc, ",") s.routes.Search(s.handleSearchEntities). BaseDn(entitySearchDN). Scope(ldap.SearchRequestHomeSubtree). @@ -111,6 +142,45 @@ func (s *server) entitySearchResult(ctx context.Context, e *pb.Entity, dn messag return res, nil } +func (s *server) entitySearchResultWithMatch(ctx context.Context, e *pb.Entity, dn message.LDAPDN, attrs message.AttributeSelection, matchGroup string) (message.SearchResultEntry, bool, error) { + s.l.Debug("Attr selection: ", "attrs", attrs, len(attrs)) + res := ldap.NewSearchResultEntry("uid=" + e.GetID() + "," + string(dn)) + res.AddAttribute("uid", message.AttributeValue(e.GetID())) + res.AddAttribute("uidNumber", message.AttributeValue(strconv.Itoa(int(e.GetNumber())))) + res.AddAttribute("displayName", message.AttributeValue(e.GetMeta().GetDisplayName())) + + mail, err := s.c.EntityKVGet(ctx, e.GetID(), "mail") + if len(mail["mail"]) > 0 { + res.AddAttribute("mail", message.AttributeValue(mail["mail"][0])) + } + + grps, err := s.c.EntityGroups(ctx, e.GetID()) + if err != nil { + return res, true, err + } + + memberOf := []message.AttributeValue{} + for i := range grps { + if grps[i].GetName() == matchGroup { + g := "cn=" + grps[i].GetName() + ",ou=groups," + strings.Join(s.nc, ",") + memberOf = append(memberOf, message.AttributeValue(g)) + } + } + if len(memberOf) > 0 { + res.AddAttribute("memberOf", memberOf...) + } else { + return res, false, nil + } + + entitySearchDN := "uid=" + e.GetID() + ",ou=entities," + strings.Join(s.nc, ",") + s.routes.Search(s.handleSearchEntities). + BaseDn(entitySearchDN). + Scope(ldap.SearchRequestHomeSubtree). + Label("Search - Entities (" + e.GetID() + ")") + + return res, true, nil +} + func (s *server) handleSearchGroups(w ldap.ResponseWriter, m *ldap.Message) { ctx := context.Background() s.l.Debug("Search Groups") @@ -137,7 +207,7 @@ func (s *server) handleSearchGroups(w ldap.ResponseWriter, m *ldap.Message) { for i := range exprSlice { if !strings.Contains(exprSlice[i], "Name") && !strings.Contains(exprSlice[i], "ID") { matchUid = exprSlice[i] - if (i+1) == len(exprSlice){ + if (i + 1) == len(exprSlice) { exprSlice = exprSlice[:i] } else { exprSlice = append(exprSlice[:i], exprSlice[i+1:]...)