diff --git a/go.mod b/go.mod index 44c05136..ffd3f928 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,8 @@ module github.com/crewjam/saml go 1.22 require ( - github.com/golang-jwt/jwt/v5 v5.2.2 github.com/beevik/etree v1.5.0 + github.com/golang-jwt/jwt/v5 v5.2.2 github.com/google/go-cmp v0.7.0 github.com/mattermost/xml-roundtrip-validator v0.1.0 github.com/russellhaering/goxmldsig v1.4.0 diff --git a/service_provider.go b/service_provider.go index c97886d0..ae919480 100644 --- a/service_provider.go +++ b/service_provider.go @@ -1812,6 +1812,10 @@ func findChild(parentEl *etree.Element, childNS string, childTag string) (*etree func elementToBytes(el *etree.Element) ([]byte, error) { namespaces := map[string]string{} for _, childEl := range el.FindElements("//*") { + if el.Tag != childEl.Tag { + continue + } + ns := childEl.NamespaceURI() if ns != "" { namespaces[childEl.Space] = ns @@ -1821,7 +1825,11 @@ func elementToBytes(el *etree.Element) ([]byte, error) { doc := etree.NewDocument() doc.SetRoot(el.Copy()) for space, uri := range namespaces { - doc.Root().CreateAttr("xmlns:"+space, uri) + if space == "" { + doc.Root().CreateAttr("xmlns", uri) + } else { + doc.Root().CreateAttr("xmlns:"+space, uri) + } } return doc.WriteToBytes() @@ -1833,6 +1841,7 @@ func unmarshalElement(el *etree.Element, v interface{}) error { if err != nil { return err } + return xml.Unmarshal(buf, v) }