From 5b507345990ec001f7f9e4e9d5b00555ba4d3b89 Mon Sep 17 00:00:00 2001 From: William Luke Date: Thu, 4 Jul 2024 10:07:25 -0500 Subject: [PATCH] Add gitea/forgejo sepcific queries support --- internal/ldap/bleve_mapper.go | 22 ++++++++-- internal/ldap/ldap.go | 5 +++ internal/ldap/search.go | 77 +++++++++++++++++++++++++++++++++-- internal/ldap/type.go | 1 + 4 files changed, 98 insertions(+), 7 deletions(-) diff --git a/internal/ldap/bleve_mapper.go b/internal/ldap/bleve_mapper.go index 6e4d680..8421acb 100644 --- a/internal/ldap/bleve_mapper.go +++ b/internal/ldap/bleve_mapper.go @@ -8,23 +8,33 @@ import ( "github.com/ps78674/goldap/message" ) -func (s *server) buildBleveQuery(f message.Filter) (string, error) { +func (s *server) buildBleveQuery(f message.Filter) (string, bool, error) { s.l.Trace("Building search expression", "type", fmt.Sprintf("%T", f), "filter", fmt.Sprintf("%#v", f)) var err error var etmp string var expr []string + var exprAnd bool = false switch f := f.(type) { case message.FilterEqualityMatch: etmp, err = mapToBleveStringQuery(string(f.AttributeDesc()), "=", string(f.AssertionValue())) expr = append(expr, etmp) case message.FilterOr: for _, subf := range f { - s, err := s.buildBleveQuery(subf) + s, _, err := s.buildBleveQuery(subf) if err != nil { - return "", err + return "", false, err } expr = append(expr, s) } + case message.FilterAnd: + for _, subf := range f { + s, _, err := s.buildBleveQuery(subf) + if err != nil { + return "", false, err + } + expr = append(expr, s) + } + exprAnd = true case message.FilterPresent: etmp, err = mapToBleveStringQuery(string(f), "=", "*") expr = append(expr, etmp) @@ -32,7 +42,7 @@ func (s *server) buildBleveQuery(f message.Filter) (string, error) { s.l.Warn("Unsupported search filter", "filter", fmt.Sprintf("%#v", f)) err = errors.New("unsupported search filter") } - return strings.Join(expr, " "), err + return strings.Join(expr, " "), exprAnd, err } func mapToBleveStringQuery(attr, op, val string) (string, error) { @@ -43,6 +53,10 @@ func mapToBleveStringQuery(attr, op, val string) (string, error) { predicate = "ID" case "cn": predicate = "Name" + case "member": + predicate = "" + op = "" + return val, nil default: return "", errors.New("search attribute is unsupported") } diff --git a/internal/ldap/ldap.go b/internal/ldap/ldap.go index 0527c1c..a286fc2 100644 --- a/internal/ldap/ldap.go +++ b/internal/ldap/ldap.go @@ -103,4 +103,9 @@ func (s *server) SetDomain(domain string) { Scope(ldap.SearchRequestHomeSubtree). Label("Search - Entities") + s.routes.Search(s.handleSearchGroups). + BaseDn(groupSearchDN). + Scope(ldap.SearchRequestHomeSubtree). + Filter("(&(members=*)(cn=gitea_users))"). + Label("Search - Group Members") } diff --git a/internal/ldap/search.go b/internal/ldap/search.go index 6fdb511..41ff024 100644 --- a/internal/ldap/search.go +++ b/internal/ldap/search.go @@ -33,7 +33,7 @@ func (s *server) handleSearchEntities(w ldap.ResponseWriter, m *ldap.Message) { s.l.Debug("Search Entities") r := m.GetSearchRequest() - expr, err := s.buildBleveQuery(r.Filter()) + expr, _, err := s.buildBleveQuery(r.Filter()) if err != nil { // If err is non-nil at this point it must mean that // the above match didn't find a supported filter. @@ -77,10 +77,19 @@ func (s *server) handleSearchEntities(w ldap.ResponseWriter, m *ldap.Message) { // plumbed down to this level to permit attribute filtering in the // future. func (s *server) entitySearchResult(ctx context.Context, e *pb.Entity, dn message.LDAPDN, attrs message.AttributeSelection) (message.SearchResultEntry, error) { + s.l.Debug("Attr selection: ", "attrs", attrs, len(attrs)) res := ldap.NewSearchResultEntry("uid=" + e.GetID() + "," + string(dn)) res.AddAttribute("uid", message.AttributeValue(e.GetID())) res.AddAttribute("uidNumber", message.AttributeValue(strconv.Itoa(int(e.GetNumber())))) + mail, err := s.c.EntityKVGet(ctx, e.GetID(), "mail") + //if err != nil { + // return res, err + //} + if len(mail["mail"]) > 0 { + res.AddAttribute("mail", message.AttributeValue(mail["mail"][0])) + } + grps, err := s.c.EntityGroups(ctx, e.GetID()) if err != nil { return res, err @@ -93,6 +102,12 @@ func (s *server) entitySearchResult(ctx context.Context, e *pb.Entity, dn messag } res.AddAttribute("memberOf", memberOf...) + entitySearchDN := "uid=" + e.GetID() + ",ou=entities," + strings.Join(s.nc, ",") + s.routes.Search(s.handleSearchEntities). + BaseDn(entitySearchDN). + Scope(ldap.SearchRequestHomeSubtree). + Label("Search - Entities (" + e.GetID() + ")") + return res, nil } @@ -102,7 +117,7 @@ func (s *server) handleSearchGroups(w ldap.ResponseWriter, m *ldap.Message) { r := m.GetSearchRequest() - expr, err := s.buildBleveQuery(r.Filter()) + expr, exprAnd, err := s.buildBleveQuery(r.Filter()) if err != nil { // If err is non-nil at this point it must mean that // the above match didn't find a supported filter. @@ -115,6 +130,25 @@ func (s *server) handleSearchGroups(w ldap.ResponseWriter, m *ldap.Message) { s.l.Debug("Searching groups", "expr", expr) + var matchUid string + if exprAnd { + s.l.Debug("Filtering by AND...") + exprSlice := strings.Split(expr, " ") + for i := range exprSlice { + if !strings.Contains(exprSlice[i], "Name") && !strings.Contains(exprSlice[i], "ID") { + matchUid = exprSlice[i] + if (i+1) == len(exprSlice){ + exprSlice = exprSlice[:i] + } else { + exprSlice = append(exprSlice[:i], exprSlice[i+1:]...) + } + break + } + } + expr = strings.Join(exprSlice, " ") + s.l.Debug("Filtered by AND", "expr", expr, "matchUid", matchUid) + } + groups, err := s.c.GroupSearch(ctx, expr) if err != nil { res := ldap.NewSearchResultDoneResponse(ldap.LDAPResultOperationsError) @@ -125,7 +159,20 @@ func (s *server) handleSearchGroups(w ldap.ResponseWriter, m *ldap.Message) { for i := range groups { s.l.Debug("Found group", "group", groups[i].GetName()) - e, err := s.groupSearchResult(ctx, groups[i], r.BaseObject(), r.Attributes()) + var e message.SearchResultEntry + var complete bool + var err error + if exprAnd { + s.l.Debug("Match UID: ", matchUid) + e, complete, err = s.groupSearchResultWithMatch(ctx, groups[i], r.BaseObject(), r.Attributes(), matchUid) + if !complete { + res := ldap.NewSearchResultDoneResponse(ldap.LDAPResultSuccess) + w.Write(res) + return + } + } else { + e, err = s.groupSearchResult(ctx, groups[i], r.BaseObject(), r.Attributes()) + } if err != nil { res := ldap.NewSearchResultDoneResponse(ldap.LDAPResultOperationsError) res.SetDiagnosticMessage(err.Error()) @@ -162,3 +209,27 @@ func (s *server) groupSearchResult(ctx context.Context, g *pb.Group, dn message. return res, nil } + +func (s *server) groupSearchResultWithMatch(ctx context.Context, g *pb.Group, dn message.LDAPDN, attrs message.AttributeSelection, matchUid string) (message.SearchResultEntry, bool, error) { + res := ldap.NewSearchResultEntry("cn=" + g.GetName() + "," + string(dn)) + res.AddAttribute("cn", message.AttributeValue(g.GetName())) + res.AddAttribute("gidNumber", message.AttributeValue(strconv.Itoa(int(g.GetNumber())))) + + members, err := s.c.GroupMembers(ctx, g.GetName()) + if err != nil { + return res, true, err + } + + var complete bool = false + memberList := []message.AttributeValue{} + for i := range members { + if members[i].GetID() == matchUid { + g := "uid=" + members[i].GetID() + ",ou=entities," + strings.Join(s.nc, ",") + memberList = append(memberList, message.AttributeValue(g)) + complete = true + } + } + res.AddAttribute("member", memberList...) + + return res, complete, nil +} diff --git a/internal/ldap/type.go b/internal/ldap/type.go index d36c03f..f4fdd00 100644 --- a/internal/ldap/type.go +++ b/internal/ldap/type.go @@ -13,6 +13,7 @@ type naClient interface { AuthEntity(context.Context, string, string) error EntitySearch(context.Context, string) ([]*pb.Entity, error) EntityGroups(context.Context, string) ([]*pb.Group, error) + EntityKVGet(context.Context, string, string) (map[string][]string, error) GroupSearch(context.Context, string) ([]*pb.Group, error) GroupMembers(context.Context, string) ([]*pb.Entity, error)