From 5d10f8ced82b7e794948ab8820480c542379eca6 Mon Sep 17 00:00:00 2001 From: pxp928 Date: Mon, 1 Jul 2024 16:00:55 -0400 Subject: [PATCH 1/3] add missing nodes from the node query Signed-off-by: pxp928 --- .../backends/ent/backend/neighbors.go | 178 +++++++++++------- pkg/assembler/backends/ent/backend/package.go | 61 ++++++ pkg/assembler/backends/ent/backend/source.go | 54 ++++++ .../backends/ent/backend/vulnerability.go | 21 +++ 4 files changed, 242 insertions(+), 72 deletions(-) diff --git a/pkg/assembler/backends/ent/backend/neighbors.go b/pkg/assembler/backends/ent/backend/neighbors.go index 812c9d7413..9d19415202 100644 --- a/pkg/assembler/backends/ent/backend/neighbors.go +++ b/pkg/assembler/backends/ent/backend/neighbors.go @@ -297,23 +297,22 @@ func (b *EntBackend) Node(ctx context.Context, node string) (model.Node, error) if foundGlobalID.nodeType == "" { return nil, fmt.Errorf("failed to parse globalID %s. Missing Node Type", node) } - // return uuid if valid, else error - nodeID, err := uuid.Parse(foundGlobalID.id) - if err != nil { - return nil, fmt.Errorf("uuid conversion from string failed with error: %w", err) - } - switch foundGlobalID.nodeType { case artifact.Table: - artifacts, err := b.Artifacts(ctx, &model.ArtifactSpec{ID: ptrfrom.String(nodeID.String())}) + artifacts, err := b.Artifacts(ctx, &model.ArtifactSpec{ID: ptrfrom.String(foundGlobalID.id)}) if err != nil { - return nil, fmt.Errorf("failed to query for Artifacts via ID: %s, with error: %w", nodeID.String(), err) + return nil, fmt.Errorf("failed to query for Artifacts via ID: %s, with error: %w", foundGlobalID.id, err) } if len(artifacts) != 1 { - return nil, fmt.Errorf("ID returned multiple Artifacts nodes %s", nodeID.String()) + return nil, fmt.Errorf("ID returned multiple Artifacts nodes %s", foundGlobalID.id) } return artifacts[0], nil case packageversion.Table: + // return uuid if valid, else error + nodeID, err := uuid.Parse(foundGlobalID.id) + if err != nil { + return nil, fmt.Errorf("uuid conversion from string failed with error: %w", err) + } pv, err := b.client.PackageVersion.Query(). Where(packageversion.ID(nodeID)). WithName(func(q *ent.PackageNameQuery) {}). @@ -323,6 +322,11 @@ func (b *EntBackend) Node(ctx context.Context, node string) (model.Node, error) } return toModelPackage(backReferencePackageVersion(pv)), nil case packagename.Table: + // return uuid if valid, else error + nodeID, err := uuid.Parse(foundGlobalID.id) + if err != nil { + return nil, fmt.Errorf("uuid conversion from string failed with error: %w", err) + } pn, err := b.client.PackageName.Query(). Where(packagename.ID(nodeID)). WithVersions(). @@ -331,193 +335,223 @@ func (b *EntBackend) Node(ctx context.Context, node string) (model.Node, error) return nil, err } return toModelPackage(backReferencePackageName(pn)), nil + case pkgNamespaceString: + pNamespace, err := b.getPkgNameSpace(ctx, foundGlobalID.id) + if err != nil { + return nil, fmt.Errorf("failed to get package namespace node with ID: %s, with error: %w", foundGlobalID.id, err) + } + return pNamespace, nil + case pkgTypeString: + pType, err := b.getPkgType(ctx, foundGlobalID.id) + if err != nil { + return nil, fmt.Errorf("failed to get package Type node with ID: %s, with error: %w", foundGlobalID.id, err) + } + return pType, nil case sourcename.Table: - sources, err := b.Sources(ctx, &model.SourceSpec{ID: ptrfrom.String(nodeID.String())}) + sources, err := b.Sources(ctx, &model.SourceSpec{ID: ptrfrom.String(foundGlobalID.id)}) if err != nil { - return nil, fmt.Errorf("failed to query for Sources via ID: %s, with error: %w", nodeID.String(), err) + return nil, fmt.Errorf("failed to query for Sources via ID: %s, with error: %w", foundGlobalID.id, err) } if len(sources) != 1 { - return nil, fmt.Errorf("ID returned multiple Sources nodes %s", nodeID.String()) + return nil, fmt.Errorf("ID returned multiple Sources nodes %s", foundGlobalID.id) } return sources[0], nil + case srcNamespaceString: + sNamespace, err := b.getSrcNameSpace(ctx, foundGlobalID.id) + if err != nil { + return nil, fmt.Errorf("failed to get source namespace node with ID: %s, with error: %w", foundGlobalID.id, err) + } + return sNamespace, nil + case srcTypeString: + sType, err := b.getSrcType(ctx, foundGlobalID.id) + if err != nil { + return nil, fmt.Errorf("failed to get source Type node with ID: %s, with error: %w", foundGlobalID.id, err) + } + return sType, nil case builder.Table: - builders, err := b.Builders(ctx, &model.BuilderSpec{ID: ptrfrom.String(nodeID.String())}) + builders, err := b.Builders(ctx, &model.BuilderSpec{ID: ptrfrom.String(foundGlobalID.id)}) if err != nil { - return nil, fmt.Errorf("failed to query for Builders via ID: %s, with error: %w", nodeID.String(), err) + return nil, fmt.Errorf("failed to query for Builders via ID: %s, with error: %w", foundGlobalID.id, err) } if len(builders) != 1 { - return nil, fmt.Errorf("ID returned multiple Builders nodes %s", nodeID.String()) + return nil, fmt.Errorf("ID returned multiple Builders nodes %s", foundGlobalID.id) } return builders[0], nil case license.Table: - licenses, err := b.Licenses(ctx, &model.LicenseSpec{ID: ptrfrom.String(nodeID.String())}) + licenses, err := b.Licenses(ctx, &model.LicenseSpec{ID: ptrfrom.String(foundGlobalID.id)}) if err != nil { - return nil, fmt.Errorf("failed to query for Licenses via ID: %s, with error: %w", nodeID.String(), err) + return nil, fmt.Errorf("failed to query for Licenses via ID: %s, with error: %w", foundGlobalID.id, err) } if len(licenses) != 1 { - return nil, fmt.Errorf("ID returned multiple Licenses nodes %s", nodeID.String()) + return nil, fmt.Errorf("ID returned multiple Licenses nodes %s", foundGlobalID.id) } return licenses[0], nil case vulnerabilityid.Table: - vulnerabilities, err := b.Vulnerabilities(ctx, &model.VulnerabilitySpec{ID: ptrfrom.String(nodeID.String())}) + vulnerabilities, err := b.Vulnerabilities(ctx, &model.VulnerabilitySpec{ID: ptrfrom.String(foundGlobalID.id)}) if err != nil { - return nil, fmt.Errorf("failed to query for Vulnerabilities via ID: %s, with error: %w", nodeID.String(), err) + return nil, fmt.Errorf("failed to query for Vulnerabilities via ID: %s, with error: %w", foundGlobalID.id, err) } if len(vulnerabilities) != 1 { - return nil, fmt.Errorf("ID returned multiple Vulnerabilities nodes %s", nodeID.String()) + return nil, fmt.Errorf("ID returned multiple Vulnerabilities nodes %s", foundGlobalID.id) } return vulnerabilities[0], nil + case vulnTypeString: + vType, err := b.getVulnType(ctx, foundGlobalID.id) + if err != nil { + return nil, fmt.Errorf("failed to get vulnerability Type node with ID: %s, with error: %w", foundGlobalID.id, err) + } + return vType, nil case certifyBadString: - certs, err := b.CertifyBad(ctx, &model.CertifyBadSpec{ID: ptrfrom.String(nodeID.String())}) + certs, err := b.CertifyBad(ctx, &model.CertifyBadSpec{ID: ptrfrom.String(foundGlobalID.id)}) if err != nil { - return nil, fmt.Errorf("failed to query for CertifyBad via ID: %s, with error: %w", nodeID.String(), err) + return nil, fmt.Errorf("failed to query for CertifyBad via ID: %s, with error: %w", foundGlobalID.id, err) } if len(certs) != 1 { - return nil, fmt.Errorf("ID returned multiple CertifyBad nodes %s", nodeID.String()) + return nil, fmt.Errorf("ID returned multiple CertifyBad nodes %s", foundGlobalID.id) } return certs[0], nil case certifyGoodString: - certs, err := b.CertifyGood(ctx, &model.CertifyGoodSpec{ID: ptrfrom.String(nodeID.String())}) + certs, err := b.CertifyGood(ctx, &model.CertifyGoodSpec{ID: ptrfrom.String(foundGlobalID.id)}) if err != nil { - return nil, fmt.Errorf("failed to query for CertifyGood via ID: %s, with error: %w", nodeID.String(), err) + return nil, fmt.Errorf("failed to query for CertifyGood via ID: %s, with error: %w", foundGlobalID.id, err) } if len(certs) != 1 { - return nil, fmt.Errorf("ID returned multiple CertifyGood nodes %s", nodeID.String()) + return nil, fmt.Errorf("ID returned multiple CertifyGood nodes %s", foundGlobalID.id) } return certs[0], nil case certifylegal.Table: - legals, err := b.CertifyLegal(ctx, &model.CertifyLegalSpec{ID: ptrfrom.String(nodeID.String())}) + legals, err := b.CertifyLegal(ctx, &model.CertifyLegalSpec{ID: ptrfrom.String(foundGlobalID.id)}) if err != nil { - return nil, fmt.Errorf("failed to query for CertifyLegal via ID: %s, with error: %w", nodeID.String(), err) + return nil, fmt.Errorf("failed to query for CertifyLegal via ID: %s, with error: %w", foundGlobalID.id, err) } if len(legals) != 1 { - return nil, fmt.Errorf("ID returned multiple CertifyLegal nodes %s", nodeID.String()) + return nil, fmt.Errorf("ID returned multiple CertifyLegal nodes %s", foundGlobalID.id) } return legals[0], nil case certifyscorecard.Table: - scores, err := b.Scorecards(ctx, &model.CertifyScorecardSpec{ID: ptrfrom.String(nodeID.String())}) + scores, err := b.Scorecards(ctx, &model.CertifyScorecardSpec{ID: ptrfrom.String(foundGlobalID.id)}) if err != nil { - return nil, fmt.Errorf("failed to query for scorecard via ID: %s, with error: %w", nodeID.String(), err) + return nil, fmt.Errorf("failed to query for scorecard via ID: %s, with error: %w", foundGlobalID.id, err) } if len(scores) != 1 { - return nil, fmt.Errorf("ID returned multiple scorecard nodes %s", nodeID.String()) + return nil, fmt.Errorf("ID returned multiple scorecard nodes %s", foundGlobalID.id) } return scores[0], nil case certifyvex.Table: - vexs, err := b.CertifyVEXStatement(ctx, &model.CertifyVEXStatementSpec{ID: ptrfrom.String(nodeID.String())}) + vexs, err := b.CertifyVEXStatement(ctx, &model.CertifyVEXStatementSpec{ID: ptrfrom.String(foundGlobalID.id)}) if err != nil { - return nil, fmt.Errorf("failed to query for CertifyVEXStatement via ID: %s, with error: %w", nodeID.String(), err) + return nil, fmt.Errorf("failed to query for CertifyVEXStatement via ID: %s, with error: %w", foundGlobalID.id, err) } if len(vexs) != 1 { - return nil, fmt.Errorf("ID returned multiple CertifyVEXStatement nodes %s", nodeID.String()) + return nil, fmt.Errorf("ID returned multiple CertifyVEXStatement nodes %s", foundGlobalID.id) } return vexs[0], nil case certifyvuln.Table: - vulns, err := b.CertifyVuln(ctx, &model.CertifyVulnSpec{ID: ptrfrom.String(nodeID.String())}) + vulns, err := b.CertifyVuln(ctx, &model.CertifyVulnSpec{ID: ptrfrom.String(foundGlobalID.id)}) if err != nil { - return nil, fmt.Errorf("failed to query for CertifyVuln via ID: %s, with error: %w", nodeID.String(), err) + return nil, fmt.Errorf("failed to query for CertifyVuln via ID: %s, with error: %w", foundGlobalID.id, err) } if len(vulns) != 1 { - return nil, fmt.Errorf("ID returned multiple CertifyVuln nodes %s", nodeID.String()) + return nil, fmt.Errorf("ID returned multiple CertifyVuln nodes %s", foundGlobalID.id) } return vulns[0], nil case hashequal.Table: - hes, err := b.HashEqual(ctx, &model.HashEqualSpec{ID: ptrfrom.String(nodeID.String())}) + hes, err := b.HashEqual(ctx, &model.HashEqualSpec{ID: ptrfrom.String(foundGlobalID.id)}) if err != nil { - return nil, fmt.Errorf("failed to query for HashEqual via ID: %s, with error: %w", nodeID.String(), err) + return nil, fmt.Errorf("failed to query for HashEqual via ID: %s, with error: %w", foundGlobalID.id, err) } if len(hes) != 1 { - return nil, fmt.Errorf("ID returned multiple HashEqual nodes %s", nodeID.String()) + return nil, fmt.Errorf("ID returned multiple HashEqual nodes %s", foundGlobalID.id) } return hes[0], nil case hasmetadata.Table: - hms, err := b.HasMetadata(ctx, &model.HasMetadataSpec{ID: ptrfrom.String(nodeID.String())}) + hms, err := b.HasMetadata(ctx, &model.HasMetadataSpec{ID: ptrfrom.String(foundGlobalID.id)}) if err != nil { - return nil, fmt.Errorf("failed to query for HasMetadata via ID: %s, with error: %w", nodeID.String(), err) + return nil, fmt.Errorf("failed to query for HasMetadata via ID: %s, with error: %w", foundGlobalID.id, err) } if len(hms) != 1 { - return nil, fmt.Errorf("ID returned multiple HasMetadata nodes %s", nodeID.String()) + return nil, fmt.Errorf("ID returned multiple HasMetadata nodes %s", foundGlobalID.id) } return hms[0], nil case billofmaterials.Table: - hbs, err := b.HasSBOM(ctx, &model.HasSBOMSpec{ID: ptrfrom.String(nodeID.String())}) + hbs, err := b.HasSBOM(ctx, &model.HasSBOMSpec{ID: ptrfrom.String(foundGlobalID.id)}) if err != nil { - return nil, fmt.Errorf("failed to query for HasSBOM via ID: %s, with error: %w", nodeID.String(), err) + return nil, fmt.Errorf("failed to query for HasSBOM via ID: %s, with error: %w", foundGlobalID.id, err) } if len(hbs) != 1 { - return nil, fmt.Errorf("ID returned multiple HasSBOM nodes %s", nodeID.String()) + return nil, fmt.Errorf("ID returned multiple HasSBOM nodes %s", foundGlobalID.id) } return hbs[0], nil case slsaattestation.Table: - slsas, err := b.HasSlsa(ctx, &model.HasSLSASpec{ID: ptrfrom.String(nodeID.String())}) + slsas, err := b.HasSlsa(ctx, &model.HasSLSASpec{ID: ptrfrom.String(foundGlobalID.id)}) if err != nil { - return nil, fmt.Errorf("failed to query for HasSlsa via ID: %s, with error: %w", nodeID.String(), err) + return nil, fmt.Errorf("failed to query for HasSlsa via ID: %s, with error: %w", foundGlobalID.id, err) } if len(slsas) != 1 { - return nil, fmt.Errorf("ID returned multiple HasSlsa nodes %s", nodeID.String()) + return nil, fmt.Errorf("ID returned multiple HasSlsa nodes %s", foundGlobalID.id) } return slsas[0], nil case hassourceat.Table: - hsas, err := b.HasSourceAt(ctx, &model.HasSourceAtSpec{ID: ptrfrom.String(nodeID.String())}) + hsas, err := b.HasSourceAt(ctx, &model.HasSourceAtSpec{ID: ptrfrom.String(foundGlobalID.id)}) if err != nil { - return nil, fmt.Errorf("failed to query for HasSourceAt via ID: %s, with error: %w", nodeID.String(), err) + return nil, fmt.Errorf("failed to query for HasSourceAt via ID: %s, with error: %w", foundGlobalID.id, err) } if len(hsas) != 1 { - return nil, fmt.Errorf("ID returned multiple HasSourceAt nodes %s", nodeID.String()) + return nil, fmt.Errorf("ID returned multiple HasSourceAt nodes %s", foundGlobalID.id) } return hsas[0], nil case dependency.Table: - deps, err := b.IsDependency(ctx, &model.IsDependencySpec{ID: ptrfrom.String(nodeID.String())}) + deps, err := b.IsDependency(ctx, &model.IsDependencySpec{ID: ptrfrom.String(foundGlobalID.id)}) if err != nil { - return nil, fmt.Errorf("failed to query for IsDependency via ID: %s, with error: %w", nodeID.String(), err) + return nil, fmt.Errorf("failed to query for IsDependency via ID: %s, with error: %w", foundGlobalID.id, err) } if len(deps) != 1 { - return nil, fmt.Errorf("ID returned multiple IsDependency nodes %s", nodeID.String()) + return nil, fmt.Errorf("ID returned multiple IsDependency nodes %s", foundGlobalID.id) } return deps[0], nil case occurrence.Table: - occurs, err := b.IsOccurrence(ctx, &model.IsOccurrenceSpec{ID: ptrfrom.String(nodeID.String())}) + occurs, err := b.IsOccurrence(ctx, &model.IsOccurrenceSpec{ID: ptrfrom.String(foundGlobalID.id)}) if err != nil { - return nil, fmt.Errorf("failed to query for IsOccurrence via ID: %s, with error: %w", nodeID.String(), err) + return nil, fmt.Errorf("failed to query for IsOccurrence via ID: %s, with error: %w", foundGlobalID.id, err) } if len(occurs) != 1 { - return nil, fmt.Errorf("ID returned multiple IsOccurrence nodes %s", nodeID.String()) + return nil, fmt.Errorf("ID returned multiple IsOccurrence nodes %s", foundGlobalID.id) } return occurs[0], nil case pkgequal.Table: - pes, err := b.PkgEqual(ctx, &model.PkgEqualSpec{ID: ptrfrom.String(nodeID.String())}) + pes, err := b.PkgEqual(ctx, &model.PkgEqualSpec{ID: ptrfrom.String(foundGlobalID.id)}) if err != nil { - return nil, fmt.Errorf("failed to query for PkgEqual via ID: %s, with error: %w", nodeID.String(), err) + return nil, fmt.Errorf("failed to query for PkgEqual via ID: %s, with error: %w", foundGlobalID.id, err) } if len(pes) != 1 { - return nil, fmt.Errorf("ID returned multiple PkgEqual nodes %s", nodeID.String()) + return nil, fmt.Errorf("ID returned multiple PkgEqual nodes %s", foundGlobalID.id) } return pes[0], nil case pointofcontact.Table: - pocs, err := b.PointOfContact(ctx, &model.PointOfContactSpec{ID: ptrfrom.String(nodeID.String())}) + pocs, err := b.PointOfContact(ctx, &model.PointOfContactSpec{ID: ptrfrom.String(foundGlobalID.id)}) if err != nil { - return nil, fmt.Errorf("failed to query for PointOfContact via ID: %s, with error: %w", nodeID.String(), err) + return nil, fmt.Errorf("failed to query for PointOfContact via ID: %s, with error: %w", foundGlobalID.id, err) } if len(pocs) != 1 { - return nil, fmt.Errorf("ID returned multiple PointOfContact nodes %s", nodeID.String()) + return nil, fmt.Errorf("ID returned multiple PointOfContact nodes %s", foundGlobalID.id) } return pocs[0], nil case vulnequal.Table: - ves, err := b.VulnEqual(ctx, &model.VulnEqualSpec{ID: ptrfrom.String(nodeID.String())}) + ves, err := b.VulnEqual(ctx, &model.VulnEqualSpec{ID: ptrfrom.String(foundGlobalID.id)}) if err != nil { - return nil, fmt.Errorf("failed to query for VulnEqual via ID: %s, with error: %w", nodeID.String(), err) + return nil, fmt.Errorf("failed to query for VulnEqual via ID: %s, with error: %w", foundGlobalID.id, err) } if len(ves) != 1 { - return nil, fmt.Errorf("ID returned multiple VulnEqual nodes %s", nodeID.String()) + return nil, fmt.Errorf("ID returned multiple VulnEqual nodes %s", foundGlobalID.id) } return ves[0], nil case vulnerabilitymetadata.Table: - vms, err := b.VulnerabilityMetadata(ctx, &model.VulnerabilityMetadataSpec{ID: ptrfrom.String(nodeID.String())}) + vms, err := b.VulnerabilityMetadata(ctx, &model.VulnerabilityMetadataSpec{ID: ptrfrom.String(foundGlobalID.id)}) if err != nil { - return nil, fmt.Errorf("failed to query for VulnerabilityMetadata via ID: %s, with error: %w", nodeID.String(), err) + return nil, fmt.Errorf("failed to query for VulnerabilityMetadata via ID: %s, with error: %w", foundGlobalID.id, err) } if len(vms) != 1 { - return nil, fmt.Errorf("ID returned multiple VulnerabilityMetadata nodes %s", nodeID.String()) + return nil, fmt.Errorf("ID returned multiple VulnerabilityMetadata nodes %s", foundGlobalID.id) } return vms[0], nil default: diff --git a/pkg/assembler/backends/ent/backend/package.go b/pkg/assembler/backends/ent/backend/package.go index fc770ed35f..c0cc889afa 100644 --- a/pkg/assembler/backends/ent/backend/package.go +++ b/pkg/assembler/backends/ent/backend/package.go @@ -467,6 +467,67 @@ func getPkgVersion(ctx context.Context, client *ent.Client, pkgin model.PkgInput return client.PackageVersion.Query().Where(packageVersionInputQuery(pkgin)).Only(ctx) } +func (b *EntBackend) getPkgNameSpace(ctx context.Context, nodeID string) (*model.Package, error) { + // split to find the type and namespace value + splitQueryValue := strings.Split(nodeID, guacIDSplit) + if len(splitQueryValue) != 2 { + return nil, fmt.Errorf("invalid query for packageNamespaceNeighbors with ID %s", nodeID) + } + query := b.client.PackageName.Query(). + Where([]predicate.PackageName{ + optionalPredicate(&splitQueryValue[0], packagename.TypeEQ), + optionalPredicate(&splitQueryValue[1], packagename.NamespaceEQ), + }...) + pn, err := query.All(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get packageNamespace for node ID: %s with error: %w", nodeID, err) + } + + if len(pn) > 0 { + pkgNamespace := &model.Package{ + ID: pkgTypeGlobalID(pn[0].Type), + Type: pn[0].Type, + Namespaces: []*model.PackageNamespace{ + { + ID: pkgNamespaceGlobalID(strings.Join([]string{pn[0].Type, pn[0].Namespace}, ":")), + Namespace: pn[0].Namespace, + Names: []*model.PackageName{{ + ID: pkgNameGlobalID(pn[0].ID.String()), + Name: pn[0].Name, + Versions: []*model.PackageVersion{}, + }}, + }, + }, + } + return pkgNamespace, nil + } else { + return nil, fmt.Errorf("failed to get packageNamespace for node ID: %s", nodeID) + } +} + +func (b *EntBackend) getPkgType(ctx context.Context, nodeID string) (*model.Package, error) { + query := b.client.PackageName.Query(). + Where([]predicate.PackageName{ + optionalPredicate(&nodeID, packagename.TypeEQ), + }...) + + pn, err := query.All(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get pkgType for node ID: %s with error: %w", nodeID, err) + } + + if len(pn) > 0 { + pkgType := &model.Package{ + ID: pkgTypeGlobalID(pn[0].Type), + Type: pn[0].Type, + Namespaces: []*model.PackageNamespace{}, + } + return pkgType, nil + } else { + return nil, fmt.Errorf("failed to get package type for node ID: %s", nodeID) + } +} + func (b *EntBackend) packageTypeNeighbors(ctx context.Context, nodeID string, allowedEdges edgeMap) ([]model.Node, error) { var out []model.Node if allowedEdges[model.EdgePackageTypePackageNamespace] { diff --git a/pkg/assembler/backends/ent/backend/source.go b/pkg/assembler/backends/ent/backend/source.go index 27657165a8..928fb0e1b6 100644 --- a/pkg/assembler/backends/ent/backend/source.go +++ b/pkg/assembler/backends/ent/backend/source.go @@ -698,6 +698,60 @@ func getSourceNameID(ctx context.Context, client *ent.Client, s model.SourceInpu return client.SourceName.Query().Where(sourceInputQuery(s)).OnlyID(ctx) } +func (b *EntBackend) getSrcNameSpace(ctx context.Context, nodeID string) (*model.Source, error) { + // split to find the type and namespace value + splitQueryValue := strings.Split(nodeID, guacIDSplit) + if len(splitQueryValue) != 2 { + return nil, fmt.Errorf("invalid query for sourceNamespace with ID %s", nodeID) + } + + query := b.client.SourceName.Query(). + Where(sourceQuery(&model.SourceSpec{Type: &splitQueryValue[0], Namespace: &splitQueryValue[1]})) + + sn, err := query.All(ctx) + if err != nil { + return nil, fmt.Errorf("failed to query for sourceNamespace with node ID: %s with error: %w", nodeID, err) + } + + if len(sn) > 0 { + srcNamespace := &model.Source{ + ID: srcTypeGlobalID(sn[0].Type), + Type: sn[0].Type, + Namespaces: []*model.SourceNamespace{ + { + ID: srcNamespaceGlobalID(strings.Join([]string{sn[0].Type, sn[0].Namespace}, guacIDSplit)), + Namespace: sn[0].Namespace, + Names: []*model.SourceName{}, + }, + }, + } + return srcNamespace, nil + } else { + return nil, fmt.Errorf("failed to get sourceNamespace for node ID: %s", nodeID) + } +} + +func (b *EntBackend) getSrcType(ctx context.Context, nodeID string) (*model.Source, error) { + query := b.client.SourceName.Query(). + Where(sourceQuery(&model.SourceSpec{Type: &nodeID})) + + sn, err := query.All(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get source type for node ID: %s with error: %w", nodeID, err) + } + + if len(sn) > 0 { + srcType := &model.Source{ + ID: srcTypeGlobalID(sn[0].Type), + Type: sn[0].Type, + Namespaces: []*model.SourceNamespace{}, + } + return srcType, nil + } else { + return nil, fmt.Errorf("failed to get source type for node ID: %s", nodeID) + } +} + func (b *EntBackend) srcTypeNeighbors(ctx context.Context, nodeID string, allowedEdges edgeMap) ([]model.Node, error) { var out []model.Node if allowedEdges[model.EdgeSourceTypeSourceNamespace] { diff --git a/pkg/assembler/backends/ent/backend/vulnerability.go b/pkg/assembler/backends/ent/backend/vulnerability.go index 242b41cfee..9bfc262346 100644 --- a/pkg/assembler/backends/ent/backend/vulnerability.go +++ b/pkg/assembler/backends/ent/backend/vulnerability.go @@ -285,6 +285,27 @@ func toModelVulnerabilityID(vulnID *ent.VulnerabilityID) *model.VulnerabilityID } } +func (b *EntBackend) getVulnType(ctx context.Context, nodeID string) (*model.Vulnerability, error) { + query := b.client.VulnerabilityID.Query(). + Where(vulnerabilityQueryPredicates(model.VulnerabilitySpec{Type: &nodeID})...) + + vulnIDs, err := query.All(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get vuln type for node ID: %s with error: %w", nodeID, err) + } + + if len(vulnIDs) > 0 { + vulnType := &model.Vulnerability{ + ID: vulnTypeGlobalID(vulnIDs[0].Type), + Type: vulnIDs[0].Type, + VulnerabilityIDs: []*model.VulnerabilityID{}, + } + return vulnType, nil + } else { + return nil, fmt.Errorf("failed to get vuln type for node ID: %s", nodeID) + } +} + func (b *EntBackend) vulnTypeNeighbors(ctx context.Context, nodeID string, allowedEdges edgeMap) ([]model.Node, error) { var out []model.Node if allowedEdges[model.EdgeVulnerabilityTypeVulnerabilityID] { From fe91d957c8e1b44b13771f556bf3342397c05928 Mon Sep 17 00:00:00 2001 From: pxp928 Date: Mon, 1 Jul 2024 16:27:07 -0400 Subject: [PATCH 2/3] add missing unit tests for backend for pkg/src/vuln name, namespace and type Signed-off-by: pxp928 --- internal/testing/backend/path_test.go | 138 ++++++++++++++++-- pkg/assembler/backends/ent/backend/package.go | 6 +- 2 files changed, 125 insertions(+), 19 deletions(-) diff --git a/internal/testing/backend/path_test.go b/internal/testing/backend/path_test.go index beaf1deed9..609e61267b 100644 --- a/internal/testing/backend/path_test.go +++ b/internal/testing/backend/path_test.go @@ -331,11 +331,17 @@ func TestNodes(t *testing.T) { } tests := []struct { name string - pkgInput *model.PkgInputSpec + pkgVersionInput *model.PkgInputSpec + pkgNameInput *model.PkgInputSpec + pkgNamespaceInput *model.PkgInputSpec + pkgTypeInput *model.PkgInputSpec artifactInput *model.ArtifactInputSpec builderInput *model.BuilderInputSpec - srcInput *model.SourceInputSpec + srcNameInput *model.SourceInputSpec + srcNamespaceInput *model.SourceInputSpec + srcTypeInput *model.SourceInputSpec vulnInput *model.VulnerabilityInputSpec + vulnTypeInput *model.VulnerabilityInputSpec licenseInput *model.LicenseInputSpec inPkg []*model.PkgInputSpec inSrc []*model.SourceInputSpec @@ -363,10 +369,41 @@ func TestNodes(t *testing.T) { want []model.Node wantErr bool }{{ - name: "package", - pkgInput: testdata.P1, - want: []model.Node{testdata.P1out}, - wantErr: false, + name: "package version", + pkgVersionInput: testdata.P1, + want: []model.Node{testdata.P1out}, + wantErr: false, + }, { + name: "package name", + pkgNameInput: testdata.P1, + want: []model.Node{&model.Package{ + Type: "pypi", + Namespaces: []*model.PackageNamespace{{ + Names: []*model.PackageName{{ + Name: "tensorflow", + Versions: []*model.PackageVersion{}, + }}, + }}, + }}, + wantErr: false, + }, { + name: "package namespace", + pkgNamespaceInput: testdata.P1, + want: []model.Node{&model.Package{ + Type: "pypi", + Namespaces: []*model.PackageNamespace{{ + Names: []*model.PackageName{}, + }}, + }}, + wantErr: false, + }, { + name: "package type", + pkgTypeInput: testdata.P1, + want: []model.Node{&model.Package{ + Type: "pypi", + Namespaces: []*model.PackageNamespace{}, + }}, + wantErr: false, }, { name: "artifact", artifactInput: &model.ArtifactInputSpec{ @@ -388,10 +425,29 @@ func TestNodes(t *testing.T) { }}, wantErr: false, }, { - name: "source", - srcInput: testdata.S1, - want: []model.Node{testdata.S1out}, - wantErr: false, + name: "source name", + srcNameInput: testdata.S1, + want: []model.Node{testdata.S1out}, + wantErr: false, + }, { + name: "source namespace", + srcNamespaceInput: testdata.S1, + want: []model.Node{&model.Source{ + Type: "git", + Namespaces: []*model.SourceNamespace{{ + Namespace: "github.com/jeff", + Names: []*model.SourceName{}, + }}, + }}, + wantErr: false, + }, { + name: "source type", + srcTypeInput: testdata.S1, + want: []model.Node{&model.Source{ + Type: "git", + Namespaces: []*model.SourceNamespace{}, + }}, + wantErr: false, }, { name: "vulnerability", vulnInput: testdata.C1, @@ -399,6 +455,13 @@ func TestNodes(t *testing.T) { Type: "cve", VulnerabilityIDs: []*model.VulnerabilityID{testdata.C1out}, }}, + }, { + name: "vulnerability type", + vulnTypeInput: testdata.C1, + want: []model.Node{&model.Vulnerability{ + Type: "cve", + VulnerabilityIDs: []*model.VulnerabilityID{}, + }}, }, { name: "license", licenseInput: testdata.L1, @@ -738,14 +801,38 @@ func TestNodes(t *testing.T) { t.Fatalf("Could not ingest vulnerability: %a", err) } } - if tt.pkgInput != nil { - ingestedPkg, err := b.IngestPackage(ctx, model.IDorPkgInput{PackageInput: tt.pkgInput}) + if tt.pkgVersionInput != nil { + ingestedPkg, err := b.IngestPackage(ctx, model.IDorPkgInput{PackageInput: tt.pkgVersionInput}) if (err != nil) != tt.wantErr { t.Errorf("arangoClient.IngestPackage() error = %v, wantErr %v", err, tt.wantErr) return } nodeID = ingestedPkg.PackageVersionID } + if tt.pkgNameInput != nil { + ingestedPkg, err := b.IngestPackage(ctx, model.IDorPkgInput{PackageInput: tt.pkgNameInput}) + if (err != nil) != tt.wantErr { + t.Errorf("arangoClient.IngestPackage() error = %v, wantErr %v", err, tt.wantErr) + return + } + nodeID = ingestedPkg.PackageNameID + } + if tt.pkgNamespaceInput != nil { + ingestedPkg, err := b.IngestPackage(ctx, model.IDorPkgInput{PackageInput: tt.pkgNamespaceInput}) + if (err != nil) != tt.wantErr { + t.Errorf("arangoClient.IngestPackage() error = %v, wantErr %v", err, tt.wantErr) + return + } + nodeID = ingestedPkg.PackageNamespaceID + } + if tt.pkgTypeInput != nil { + ingestedPkg, err := b.IngestPackage(ctx, model.IDorPkgInput{PackageInput: tt.pkgTypeInput}) + if (err != nil) != tt.wantErr { + t.Errorf("arangoClient.IngestPackage() error = %v, wantErr %v", err, tt.wantErr) + return + } + nodeID = ingestedPkg.PackageTypeID + } if tt.artifactInput != nil { ingestedArtID, err := b.IngestArtifact(ctx, &model.IDorArtifactInput{ArtifactInput: tt.artifactInput}) if (err != nil) != tt.wantErr { @@ -762,14 +849,30 @@ func TestNodes(t *testing.T) { } nodeID = ingestedBuilderID } - if tt.srcInput != nil { - ingestedSrc, err := b.IngestSource(ctx, model.IDorSourceInput{SourceInput: tt.srcInput}) + if tt.srcNameInput != nil { + ingestedSrc, err := b.IngestSource(ctx, model.IDorSourceInput{SourceInput: tt.srcNameInput}) if (err != nil) != tt.wantErr { t.Errorf("arangoClient.IngestSource() error = %v, wantErr %v", err, tt.wantErr) return } nodeID = ingestedSrc.SourceNameID } + if tt.srcNamespaceInput != nil { + ingestedSrc, err := b.IngestSource(ctx, model.IDorSourceInput{SourceInput: tt.srcNamespaceInput}) + if (err != nil) != tt.wantErr { + t.Errorf("arangoClient.IngestSource() error = %v, wantErr %v", err, tt.wantErr) + return + } + nodeID = ingestedSrc.SourceNamespaceID + } + if tt.srcTypeInput != nil { + ingestedSrc, err := b.IngestSource(ctx, model.IDorSourceInput{SourceInput: tt.srcTypeInput}) + if (err != nil) != tt.wantErr { + t.Errorf("arangoClient.IngestSource() error = %v, wantErr %v", err, tt.wantErr) + return + } + nodeID = ingestedSrc.SourceTypeID + } if tt.vulnInput != nil { ingestVuln, err := b.IngestVulnerability(ctx, model.IDorVulnerabilityInput{VulnerabilityInput: tt.vulnInput}) if (err != nil) != tt.wantErr { @@ -777,6 +880,13 @@ func TestNodes(t *testing.T) { } nodeID = ingestVuln.VulnerabilityNodeID } + if tt.vulnTypeInput != nil { + ingestVuln, err := b.IngestVulnerability(ctx, model.IDorVulnerabilityInput{VulnerabilityInput: tt.vulnTypeInput}) + if (err != nil) != tt.wantErr { + t.Fatalf("did not get expected ingest error, want: %v, got: %v", tt.want, err) + } + nodeID = ingestVuln.VulnerabilityTypeID + } if tt.licenseInput != nil { ingestedLicenseID, err := b.IngestLicense(ctx, &model.IDorLicenseInput{LicenseInput: tt.licenseInput}) if (err != nil) != tt.wantErr { diff --git a/pkg/assembler/backends/ent/backend/package.go b/pkg/assembler/backends/ent/backend/package.go index c0cc889afa..710527e56f 100644 --- a/pkg/assembler/backends/ent/backend/package.go +++ b/pkg/assembler/backends/ent/backend/package.go @@ -491,11 +491,7 @@ func (b *EntBackend) getPkgNameSpace(ctx context.Context, nodeID string) (*model { ID: pkgNamespaceGlobalID(strings.Join([]string{pn[0].Type, pn[0].Namespace}, ":")), Namespace: pn[0].Namespace, - Names: []*model.PackageName{{ - ID: pkgNameGlobalID(pn[0].ID.String()), - Name: pn[0].Name, - Versions: []*model.PackageVersion{}, - }}, + Names: []*model.PackageName{}, }, }, } From 573a206cae94740554b3d06fdb2c6d8f68aeaa93 Mon Sep 17 00:00:00 2001 From: pxp928 Date: Tue, 2 Jul 2024 09:16:58 -0400 Subject: [PATCH 3/3] change colon to guacIDSplit Signed-off-by: pxp928 --- pkg/assembler/backends/ent/backend/package.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/assembler/backends/ent/backend/package.go b/pkg/assembler/backends/ent/backend/package.go index 710527e56f..f5fec55b35 100644 --- a/pkg/assembler/backends/ent/backend/package.go +++ b/pkg/assembler/backends/ent/backend/package.go @@ -489,7 +489,7 @@ func (b *EntBackend) getPkgNameSpace(ctx context.Context, nodeID string) (*model Type: pn[0].Type, Namespaces: []*model.PackageNamespace{ { - ID: pkgNamespaceGlobalID(strings.Join([]string{pn[0].Type, pn[0].Namespace}, ":")), + ID: pkgNamespaceGlobalID(strings.Join([]string{pn[0].Type, pn[0].Namespace}, guacIDSplit)), Namespace: pn[0].Namespace, Names: []*model.PackageName{}, }, @@ -580,7 +580,7 @@ func (b *EntBackend) packageNamespaceNeighbors(ctx context.Context, nodeID strin Type: foundPkgName.Type, Namespaces: []*model.PackageNamespace{ { - ID: pkgNamespaceGlobalID(strings.Join([]string{foundPkgName.Type, foundPkgName.Namespace}, ":")), + ID: pkgNamespaceGlobalID(strings.Join([]string{foundPkgName.Type, foundPkgName.Namespace}, guacIDSplit)), Namespace: foundPkgName.Namespace, Names: []*model.PackageName{{ ID: pkgNameGlobalID(foundPkgName.ID.String()),