peridot/apollo/db/mock/mock.go
2022-10-30 02:59:43 +01:00

763 lines
20 KiB
Go

// Copyright (c) All respective contributors to the Peridot Project. All rights reserved.
// Copyright (c) 2021-2022 Rocky Enterprise Software Foundation, Inc. All rights reserved.
// Copyright (c) 2021-2022 Ctrl IQ, Inc. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its contributors
// may be used to endorse or promote products derived from this software without
// specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
package apollomock
import (
"database/sql"
"fmt"
"github.com/jmoiron/sqlx/types"
apollodb "peridot.resf.org/apollo/db"
apollopb "peridot.resf.org/apollo/pb"
"peridot.resf.org/utils"
"time"
)
// Access is an in-memory stand-in for the database-backed apollodb.Access
// implementation. Each field is a plain slice of records that the methods on
// *Access scan linearly; nothing is persisted.
type Access struct {
	// Core records.
	ShortCodes         []*apollodb.ShortCode
	Advisories         []*apollodb.Advisory
	Cves               []*apollodb.CVE
	Fixes              []*apollodb.Fix
	AdvisoryReferences []*apollodb.AdvisoryReference
	Products           []*apollodb.Product
	AffectedProducts   []*apollodb.AffectedProduct
	BuildReferences    []*apollodb.BuildReference
	MirrorStates       []*apollodb.MirrorState
	// Link records tying an advisory ID to a CVE ID or fix ID.
	AdvisoryCVEs  []*apollodb.AdvisoryCVE
	AdvisoryFixes []*apollodb.AdvisoryFix
	// Package lists.
	IgnoredUpstreamPackages []*apollodb.IgnoredUpstreamPackage
	RebootSuggestedPackages []*apollodb.RebootSuggestedPackage
	AdvisoryRPMs            []*apollodb.AdvisoryRPM
}
// New returns an Access whose record slices are all empty but non-nil.
func New() *Access {
	access := new(Access)

	access.ShortCodes = make([]*apollodb.ShortCode, 0)
	access.Advisories = make([]*apollodb.Advisory, 0)
	access.Cves = make([]*apollodb.CVE, 0)
	access.Fixes = make([]*apollodb.Fix, 0)
	access.AdvisoryReferences = make([]*apollodb.AdvisoryReference, 0)
	access.Products = make([]*apollodb.Product, 0)
	access.AffectedProducts = make([]*apollodb.AffectedProduct, 0)
	access.BuildReferences = make([]*apollodb.BuildReference, 0)
	access.MirrorStates = make([]*apollodb.MirrorState, 0)
	access.AdvisoryCVEs = make([]*apollodb.AdvisoryCVE, 0)
	access.AdvisoryFixes = make([]*apollodb.AdvisoryFix, 0)
	access.IgnoredUpstreamPackages = make([]*apollodb.IgnoredUpstreamPackage, 0)
	access.RebootSuggestedPackages = make([]*apollodb.RebootSuggestedPackage, 0)
	access.AdvisoryRPMs = make([]*apollodb.AdvisoryRPM, 0)

	return access
}
// GetAllShortCodes returns the stored short codes. The returned slice is the
// mock's own backing slice, not a copy.
func (a *Access) GetAllShortCodes() ([]*apollodb.ShortCode, error) {
	shortCodes := a.ShortCodes
	return shortCodes, nil
}
// GetShortCodeByCode returns the first stored short code whose Code equals
// code, or sql.ErrNoRows when none matches.
func (a *Access) GetShortCodeByCode(code string) (*apollodb.ShortCode, error) {
	for i := range a.ShortCodes {
		if candidate := a.ShortCodes[i]; candidate.Code == code {
			return candidate, nil
		}
	}

	return nil, sql.ErrNoRows
}
// CreateShortCode stores a new short code with the given code and mode,
// stamped with the current time, and returns the stored record. ArchivedAt is
// left at its zero (invalid) value.
func (a *Access) CreateShortCode(code string, mode apollopb.ShortCode_Mode) (*apollodb.ShortCode, error) {
	createdAt := time.Now()

	created := &apollodb.ShortCode{
		Code:      code,
		Mode:      int8(mode),
		CreatedAt: &createdAt,
	}
	a.ShortCodes = append(a.ShortCodes, created)

	return created, nil
}
// getAdvisoriesWithJoin selects the advisories accepted by filter and then
// rebuilds their derived string lists (AffectedProducts, Fixes, Cves,
// References, RPMs, BuildArtifacts) from the other record slices.
//
// The filter runs before the lists are rebuilt, so it sees whatever values an
// advisory carried when this method was called. The selected advisories are
// mutated in place. A nil slice is returned when nothing matches.
func (a *Access) getAdvisoriesWithJoin(filter func(*apollodb.Advisory) bool) []*apollodb.Advisory {
	var matched []*apollodb.Advisory
	for _, candidate := range a.Advisories {
		if filter(candidate) {
			matched = append(matched, candidate)
		}
	}

	for _, adv := range matched {
		// Start every derived list from empty so stale values are discarded.
		adv.AffectedProducts = []string{}
		adv.Fixes = []string{}
		adv.Cves = []string{}
		adv.References = []string{}
		adv.RPMs = []string{}
		adv.BuildArtifacts = []string{}

		// CVE links drive build artifacts, CVE strings and affected products.
		for _, link := range a.AdvisoryCVEs {
			if link.AdvisoryID != adv.ID {
				continue
			}

			// Build artifacts are appended without de-duplication.
			for _, ref := range a.BuildReferences {
				if ref.CveID != link.CveID {
					continue
				}
				artifact := fmt.Sprintf("%s:::%s", ref.Rpm, ref.SrcRpm)
				adv.BuildArtifacts = append(adv.BuildArtifacts, artifact)
			}

			for _, cve := range a.Cves {
				if cve.ID != link.CveID {
					continue
				}
				entry := fmt.Sprintf("%s:::%s:::%s", cve.SourceBy.String, cve.SourceLink.String, cve.ID)
				if utils.StrContains(entry, adv.Cves) {
					continue
				}
				adv.Cves = append(adv.Cves, entry)
			}

			for _, affected := range a.AffectedProducts {
				if affected.CveID.String != link.CveID {
					continue
				}
				for _, product := range a.Products {
					if affected.ProductID != product.ID {
						continue
					}
					if utils.StrContains(product.Name, adv.AffectedProducts) {
						continue
					}
					adv.AffectedProducts = append(adv.AffectedProducts, product.Name)
				}
			}
		}

		// Fix tickets linked to this advisory.
		for _, link := range a.AdvisoryFixes {
			if link.AdvisoryID != adv.ID {
				continue
			}
			for _, fix := range a.Fixes {
				if fix.ID != link.FixID {
					continue
				}
				if utils.StrContains(fix.Ticket.String, adv.Fixes) {
					continue
				}
				adv.Fixes = append(adv.Fixes, fix.Ticket.String)
			}
		}

		// Reference URLs attached to this advisory.
		for _, ref := range a.AdvisoryReferences {
			if ref.AdvisoryId != adv.ID {
				continue
			}
			if utils.StrContains(ref.URL, adv.References) {
				continue
			}
			adv.References = append(adv.References, ref.URL)
		}

		// RPM names attached to this advisory.
		for _, rpm := range a.AdvisoryRPMs {
			if rpm.AdvisoryID != adv.ID {
				continue
			}
			if utils.StrContains(rpm.Name, adv.RPMs) {
				continue
			}
			adv.RPMs = append(adv.RPMs, rpm.Name)
		}
	}

	return matched
}
// GetAllAdvisories returns the advisories that pass filters, with their
// derived lists populated by getAdvisoriesWithJoin.
//
// An advisory is rejected when:
//   - filters.Product is set and its value is not in the advisory's current
//     AffectedProducts list (as it stands when the filter runs);
//   - it is published and falls after filters.Before or before filters.After;
//   - it is unpublished and filters.IncludeUnpublished is unset or false.
//
// page and limit are accepted but not used by this mock. filters must be
// non-nil.
func (a *Access) GetAllAdvisories(filters *apollopb.AdvisoryFilters, page int32, limit int32) ([]*apollodb.Advisory, error) {
	accept := func(advisory *apollodb.Advisory) bool {
		if filters.Product != nil && !utils.StrContains(filters.Product.Value, advisory.AffectedProducts) {
			return false
		}

		published := advisory.PublishedAt.Valid
		if published {
			publishedAt := advisory.PublishedAt.Time
			if filters.Before != nil && publishedAt.After(filters.Before.AsTime()) {
				return false
			}
			if filters.After != nil && publishedAt.Before(filters.After.AsTime()) {
				return false
			}
		}

		// Unpublished advisories are only kept on explicit request.
		includeUnpublished := filters.IncludeUnpublished != nil && filters.IncludeUnpublished.Value
		if !published && !includeUnpublished {
			return false
		}

		return true
	}

	return a.getAdvisoriesWithJoin(accept), nil
}
// GetAdvisoryByCodeAndYearAndNum returns the first advisory matching the
// short code, year and number, with its derived lists populated. It returns
// sql.ErrNoRows when no advisory matches.
func (a *Access) GetAdvisoryByCodeAndYearAndNum(code string, year int, num int) (*apollodb.Advisory, error) {
	found := a.getAdvisoriesWithJoin(func(candidate *apollodb.Advisory) bool {
		return candidate.ShortCodeCode == code && candidate.Year == year && candidate.Num == num
	})

	if len(found) > 0 {
		return found[0], nil
	}

	return nil, sql.ErrNoRows
}
// CreateAdvisory stores a copy of the given advisory under a freshly assigned
// ID and returns the stored record.
//
// The new ID is one greater than the ID of the last stored advisory, or 1 when
// none exist. CreatedAt is set to the current time; the derived list fields
// (Cves, Fixes, ...) are not copied from the input.
func (a *Access) CreateAdvisory(advisory *apollodb.Advisory) (*apollodb.Advisory, error) {
	var lastId int64 = 1
	if len(a.Advisories) > 0 {
		lastId = a.Advisories[len(a.Advisories)-1].ID + 1
	}

	now := time.Now()
	ret := &apollodb.Advisory{
		ID:             lastId,
		CreatedAt:      &now,
		Year:           advisory.Year,
		Num:            advisory.Num,
		Synopsis:       advisory.Synopsis,
		Topic:          advisory.Topic,
		Severity:       advisory.Severity,
		Type:           advisory.Type,
		Description:    advisory.Description,
		Solution:       advisory.Solution,
		RedHatIssuedAt: advisory.RedHatIssuedAt,
		ShortCodeCode:  advisory.ShortCodeCode,
		PublishedAt:    advisory.PublishedAt,
	}

	// Persist the record; without this the advisory could never be found by
	// the lookup or update methods and every created advisory would get ID 1.
	a.Advisories = append(a.Advisories, ret)

	return ret, nil
}
// UpdateAdvisory overwrites the editable fields of the stored advisory whose
// ID matches advisory.ID and returns the stored record. ID, CreatedAt,
// RedHatIssuedAt and the derived lists are left untouched. It returns
// sql.ErrNoRows when no stored advisory has that ID.
func (a *Access) UpdateAdvisory(advisory *apollodb.Advisory) (*apollodb.Advisory, error) {
	for i := range a.Advisories {
		stored := a.Advisories[i]
		if stored.ID != advisory.ID {
			continue
		}

		stored.Year, stored.Num = advisory.Year, advisory.Num
		stored.Synopsis = advisory.Synopsis
		stored.Topic = advisory.Topic
		stored.Severity = advisory.Severity
		stored.Type = advisory.Type
		stored.Description = advisory.Description
		stored.Solution = advisory.Solution
		stored.ShortCodeCode = advisory.ShortCodeCode
		stored.PublishedAt = advisory.PublishedAt

		return stored, nil
	}

	return nil, sql.ErrNoRows
}
// GetAllUnresolvedCVEs returns CVE rows in two passes.
//
// First, for every CVE and every affected product referencing it whose state
// is under investigation or affected (upstream or downstream), a copy of the
// CVE is returned with AffectedProductId set to that affected product's ID. A
// CVE with several such affected products yields several rows.
//
// Second, every CVE that produced no row in the first pass is appended as-is
// (the stored pointer, AffectedProductId unchanged).
func (a *Access) GetAllUnresolvedCVEs() ([]*apollodb.CVE, error) {
	var cves []*apollodb.CVE
	var addedCVEIds []string

	for _, cve := range a.Cves {
		for _, affectedProduct := range a.AffectedProducts {
			if affectedProduct.CveID.String != cve.ID {
				continue
			}

			switch affectedProduct.State {
			case
				int(apollopb.AffectedProduct_STATE_UNDER_INVESTIGATION_UPSTREAM),
				int(apollopb.AffectedProduct_STATE_UNDER_INVESTIGATION_DOWNSTREAM),
				int(apollopb.AffectedProduct_STATE_AFFECTED_UPSTREAM),
				int(apollopb.AffectedProduct_STATE_AFFECTED_DOWNSTREAM):
				// Copy so the stored CVE is not mutated.
				nCve := *cve
				nCve.AffectedProductId = sql.NullInt64{Valid: true, Int64: affectedProduct.ID}
				cves = append(cves, &nCve)
				// Remember the CVE so the second pass does not emit it again.
				addedCVEIds = append(addedCVEIds, nCve.ID)
				// No break: inside a switch it would only leave the switch,
				// so the scan over affected products continues either way.
			}
		}
	}

	for _, cve := range a.Cves {
		if !utils.StrContains(cve.ID, addedCVEIds) {
			cves = append(cves, cve)
		}
	}

	return cves, nil
}
// GetPendingAffectedProducts returns the affected products whose state is
// STATE_FIXED_UPSTREAM. The result is nil when there are none.
func (a *Access) GetPendingAffectedProducts() ([]*apollodb.AffectedProduct, error) {
	pendingState := int(apollopb.AffectedProduct_STATE_FIXED_UPSTREAM)

	var pending []*apollodb.AffectedProduct
	for _, candidate := range a.AffectedProducts {
		if candidate.State != pendingState {
			continue
		}
		pending = append(pending, candidate)
	}

	return pending, nil
}
// GetAllCVEsFixedDownstream returns, for each CVE that has at least one
// affected product in STATE_FIXED_DOWNSTREAM, a copy of the CVE with
// AffectedProductId set to the first such affected product's ID. At most one
// row is produced per CVE.
func (a *Access) GetAllCVEsFixedDownstream() ([]*apollodb.CVE, error) {
	fixedState := int(apollopb.AffectedProduct_STATE_FIXED_DOWNSTREAM)

	var fixed []*apollodb.CVE
	for _, cve := range a.Cves {
		for _, affected := range a.AffectedProducts {
			if affected.CveID.String != cve.ID || affected.State != fixedState {
				continue
			}

			// Copy so the stored CVE is not mutated.
			row := *cve
			row.AffectedProductId = sql.NullInt64{Valid: true, Int64: affected.ID}
			fixed = append(fixed, &row)

			// First match wins; move on to the next CVE.
			break
		}
	}

	return fixed, nil
}
// GetCVEByID returns the first stored CVE whose ID equals id, or
// sql.ErrNoRows when none matches.
func (a *Access) GetCVEByID(id string) (*apollodb.CVE, error) {
	for i := range a.Cves {
		if candidate := a.Cves[i]; candidate.ID == id {
			return candidate, nil
		}
	}

	return nil, sql.ErrNoRows
}
// GetAllCVEs returns the stored CVEs. The returned slice is the mock's own
// backing slice, not a copy.
func (a *Access) GetAllCVEs() ([]*apollodb.CVE, error) {
	all := a.Cves
	return all, nil
}
// CreateCVE stores a new CVE with the given ID, short code and content,
// stamped with the current time, and returns the stored record. SourceBy and
// SourceLink are marked valid only when the corresponding pointer is non-nil;
// AdvisoryId is left invalid.
func (a *Access) CreateCVE(cveId string, shortCode string, sourceBy *string, sourceLink *string, content types.NullJSONText) (*apollodb.CVE, error) {
	createdAt := time.Now()

	created := &apollodb.CVE{
		ID:        cveId,
		CreatedAt: &createdAt,
		ShortCode: shortCode,
		Content:   content,
	}
	if sourceBy != nil {
		created.SourceBy = sql.NullString{String: *sourceBy, Valid: true}
	}
	if sourceLink != nil {
		created.SourceLink = sql.NullString{String: *sourceLink, Valid: true}
	}

	a.Cves = append(a.Cves, created)
	return created, nil
}
// SetCVEContent replaces the content of the first stored CVE whose ID equals
// cveId with a valid NullJSONText wrapping content. It returns sql.ErrNoRows
// when no such CVE exists.
func (a *Access) SetCVEContent(cveId string, content types.JSONText) error {
	for i := range a.Cves {
		target := a.Cves[i]
		if target.ID != cveId {
			continue
		}

		target.Content = types.NullJSONText{JSONText: content, Valid: true}
		return nil
	}

	return sql.ErrNoRows
}
// GetProductsByShortCode returns the products whose ShortCode equals code.
// The result is nil when there are none.
func (a *Access) GetProductsByShortCode(code string) ([]*apollodb.Product, error) {
	var matching []*apollodb.Product
	for _, candidate := range a.Products {
		if candidate.ShortCode != code {
			continue
		}
		matching = append(matching, candidate)
	}

	return matching, nil
}
// GetProductByNameAndShortCode returns the first product matching both name
// and short code, or sql.ErrNoRows when none matches.
func (a *Access) GetProductByNameAndShortCode(name string, code string) (*apollodb.Product, error) {
	for i := range a.Products {
		candidate := a.Products[i]
		if candidate.Name != name || candidate.ShortCode != code {
			continue
		}
		return candidate, nil
	}

	return nil, sql.ErrNoRows
}
// GetProductByID returns the first product whose ID equals id, or
// sql.ErrNoRows when none matches.
func (a *Access) GetProductByID(id int64) (*apollodb.Product, error) {
	for i := range a.Products {
		if candidate := a.Products[i]; candidate.ID == id {
			return candidate, nil
		}
	}

	return nil, sql.ErrNoRows
}
// CreateProduct stores a new product and returns the stored record.
//
// The new ID is one greater than the ID of the last stored product, or 1 when
// none exist. RedHatMajorVersion is marked valid only when redHatMajorVersion
// is non-nil; MirrorFromDate and RedHatProductPrefix are left invalid.
func (a *Access) CreateProduct(name string, currentFullVersion string, redHatMajorVersion *int32, code string, archs []string) (*apollodb.Product, error) {
	nextID := int64(1)
	if n := len(a.Products); n > 0 {
		nextID = a.Products[n-1].ID + 1
	}

	created := &apollodb.Product{
		ID:                 nextID,
		Name:               name,
		CurrentFullVersion: currentFullVersion,
		ShortCode:          code,
		Archs:              archs,
	}
	if redHatMajorVersion != nil {
		created.RedHatMajorVersion = sql.NullInt32{Int32: *redHatMajorVersion, Valid: true}
	}

	a.Products = append(a.Products, created)
	return created, nil
}
// GetAllAffectedProductsByCVE returns the affected products whose CVE ID
// equals cve. The result is nil when there are none.
func (a *Access) GetAllAffectedProductsByCVE(cve string) ([]*apollodb.AffectedProduct, error) {
	var matching []*apollodb.AffectedProduct
	for _, candidate := range a.AffectedProducts {
		if candidate.CveID.String != cve {
			continue
		}
		matching = append(matching, candidate)
	}

	return matching, nil
}
// GetAffectedProductByCVEAndPackage returns the first affected product that
// matches both the CVE ID and the package, or sql.ErrNoRows when none does.
func (a *Access) GetAffectedProductByCVEAndPackage(cve string, pkg string) (*apollodb.AffectedProduct, error) {
	for i := range a.AffectedProducts {
		candidate := a.AffectedProducts[i]
		if candidate.CveID.String != cve || candidate.Package != pkg {
			continue
		}
		return candidate, nil
	}

	return nil, sql.ErrNoRows
}
// GetAffectedProductByAdvisory returns the first affected product whose
// Advisory string equals advisory, or sql.ErrNoRows when none matches.
func (a *Access) GetAffectedProductByAdvisory(advisory string) (*apollodb.AffectedProduct, error) {
	for i := range a.AffectedProducts {
		if candidate := a.AffectedProducts[i]; candidate.Advisory.String == advisory {
			return candidate, nil
		}
	}

	return nil, sql.ErrNoRows
}
// GetAffectedProductByID returns the first affected product whose ID equals
// id, or sql.ErrNoRows when none matches.
func (a *Access) GetAffectedProductByID(id int64) (*apollodb.AffectedProduct, error) {
	for i := range a.AffectedProducts {
		if candidate := a.AffectedProducts[i]; candidate.ID == id {
			return candidate, nil
		}
	}

	return nil, sql.ErrNoRows
}
// CreateAffectedProduct stores a new affected product and returns the stored
// record.
//
// The new ID is one greater than the ID of the last stored affected product,
// or 1 when none exist. CveID is always marked valid; Advisory is marked valid
// only when advisory is non-nil.
func (a *Access) CreateAffectedProduct(productId int64, cveId string, state int, version string, pkg string, advisory *string) (*apollodb.AffectedProduct, error) {
	nextID := int64(1)
	if n := len(a.AffectedProducts); n > 0 {
		nextID = a.AffectedProducts[n-1].ID + 1
	}

	created := &apollodb.AffectedProduct{
		ID:        nextID,
		ProductID: productId,
		CveID:     sql.NullString{String: cveId, Valid: true},
		State:     state,
		Version:   version,
		Package:   pkg,
	}
	if advisory != nil {
		created.Advisory = sql.NullString{String: *advisory, Valid: true}
	}

	a.AffectedProducts = append(a.AffectedProducts, created)
	return created, nil
}
// UpdateAffectedProductStateAndPackageAndAdvisory overwrites the state,
// package and advisory of the first affected product whose ID equals id. A nil
// advisory resets the stored Advisory to the invalid zero value. It returns
// sql.ErrNoRows when no affected product has that ID.
func (a *Access) UpdateAffectedProductStateAndPackageAndAdvisory(id int64, state int, pkg string, advisory *string) error {
	var newAdvisory sql.NullString
	if advisory != nil {
		newAdvisory = sql.NullString{String: *advisory, Valid: true}
	}

	for i := range a.AffectedProducts {
		target := a.AffectedProducts[i]
		if target.ID != id {
			continue
		}

		target.State = state
		target.Package = pkg
		target.Advisory = newAdvisory
		return nil
	}

	return sql.ErrNoRows
}
// DeleteAffectedProduct removes the affected product whose ID equals id from
// the store. If several records share the ID, the last one is removed. It
// returns sql.ErrNoRows when no affected product has that ID.
func (a *Access) DeleteAffectedProduct(id int64) error {
	// Record the index by value. Taking the address of the range variable is
	// unsafe before Go 1.22, where all iterations share one variable and the
	// pointer would end up referring to the final index of the slice.
	index := -1
	for i, affectedProduct := range a.AffectedProducts {
		if affectedProduct.ID == id {
			index = i
		}
	}
	if index < 0 {
		return sql.ErrNoRows
	}

	a.AffectedProducts = append(a.AffectedProducts[:index], a.AffectedProducts[index+1:]...)

	return nil
}
// CreateFix stores a new fix and returns its ID.
//
// The new ID is one greater than the ID of the last stored fix, or 1 when none
// exist. All four string fields are stored as valid null-strings.
func (a *Access) CreateFix(ticket string, sourceBy string, sourceLink string, description string) (int64, error) {
	nextID := int64(1)
	if n := len(a.Fixes); n > 0 {
		nextID = a.Fixes[n-1].ID + 1
	}

	valid := func(s string) sql.NullString {
		return sql.NullString{String: s, Valid: true}
	}

	a.Fixes = append(a.Fixes, &apollodb.Fix{
		ID:          nextID,
		Ticket:      valid(ticket),
		SourceBy:    valid(sourceBy),
		SourceLink:  valid(sourceLink),
		Description: valid(description),
	})

	return nextID, nil
}
// GetMirrorState returns the last stored mirror state for code that has a
// valid LastSync, or sql.ErrNoRows when there is none.
func (a *Access) GetMirrorState(code string) (*apollodb.MirrorState, error) {
	// Walk backwards so the first hit is the last matching entry.
	for i := len(a.MirrorStates) - 1; i >= 0; i-- {
		state := a.MirrorStates[i]
		if state.ShortCode == code && state.LastSync.Valid {
			return state, nil
		}
	}

	return nil, sql.ErrNoRows
}
// UpdateMirrorState sets LastSync on the first mirror state for code, or
// stores a new mirror state carrying only that value when none exists.
// lastSync must be non-nil; it is dereferenced unconditionally.
func (a *Access) UpdateMirrorState(code string, lastSync *time.Time) error {
	syncedAt := sql.NullTime{Time: *lastSync, Valid: true}

	for i := range a.MirrorStates {
		if existing := a.MirrorStates[i]; existing.ShortCode == code {
			existing.LastSync = syncedAt
			return nil
		}
	}

	a.MirrorStates = append(a.MirrorStates, &apollodb.MirrorState{
		ShortCode: code,
		LastSync:  syncedAt,
	})
	return nil
}
// UpdateMirrorStateErrata sets ErrataAfter on the first mirror state for code,
// or stores a new mirror state carrying only that value when none exists.
// lastSync must be non-nil; it is dereferenced unconditionally.
func (a *Access) UpdateMirrorStateErrata(code string, lastSync *time.Time) error {
	errataAfter := sql.NullTime{Time: *lastSync, Valid: true}

	for i := range a.MirrorStates {
		if existing := a.MirrorStates[i]; existing.ShortCode == code {
			existing.ErrataAfter = errataAfter
			return nil
		}
	}

	a.MirrorStates = append(a.MirrorStates, &apollodb.MirrorState{
		ShortCode:   code,
		ErrataAfter: errataAfter,
	})
	return nil
}
// GetMaxLastSync returns the latest valid LastSync across all mirror states,
// or sql.ErrNoRows when no mirror state has a valid LastSync. The returned
// pointer refers to the Time field of the stored record, not a copy.
func (a *Access) GetMaxLastSync() (*time.Time, error) {
	var latest *time.Time

	for _, state := range a.MirrorStates {
		if !state.LastSync.Valid {
			continue
		}
		// Keep the current candidate unless this one is strictly later.
		if latest != nil && !state.LastSync.Time.After(*latest) {
			continue
		}
		latest = &state.LastSync.Time
	}

	if latest != nil {
		return latest, nil
	}
	return nil, sql.ErrNoRows
}
// CreateBuildReference stores a new build reference and returns the stored
// record.
//
// The new ID is one greater than the ID of the last stored build reference, or
// 1 when none exist. KojiID and PeridotID are marked valid only when the
// corresponding pointer is non-nil.
func (a *Access) CreateBuildReference(affectedProductId int64, rpm string, srcRpm string, cveId string, sha256Sum string, kojiId *string, peridotId *string) (*apollodb.BuildReference, error) {
	nextID := int64(1)
	if n := len(a.BuildReferences); n > 0 {
		nextID = a.BuildReferences[n-1].ID + 1
	}

	// nullable wraps an optional string; nil maps to the invalid zero value.
	nullable := func(s *string) sql.NullString {
		if s == nil {
			return sql.NullString{}
		}
		return sql.NullString{String: *s, Valid: true}
	}

	created := &apollodb.BuildReference{
		ID:                nextID,
		AffectedProductId: affectedProductId,
		Rpm:               rpm,
		SrcRpm:            srcRpm,
		CveID:             cveId,
		Sha256Sum:         sha256Sum,
		KojiID:            nullable(kojiId),
		PeridotID:         nullable(peridotId),
	}

	a.BuildReferences = append(a.BuildReferences, created)
	return created, nil
}
// CreateAdvisoryReference stores a reference URL for the given advisory ID.
// The new record's ID is one greater than the ID of the last stored reference,
// or 1 when none exist.
func (a *Access) CreateAdvisoryReference(advisoryId int64, url string) error {
	nextID := int64(1)
	if n := len(a.AdvisoryReferences); n > 0 {
		nextID = a.AdvisoryReferences[n-1].ID + 1
	}

	a.AdvisoryReferences = append(a.AdvisoryReferences, &apollodb.AdvisoryReference{
		ID:         nextID,
		URL:        url,
		AdvisoryId: advisoryId,
	})

	return nil
}
// GetAllIgnoredPackagesByProductID returns the package names of the ignored
// upstream packages recorded for productID. The result is nil when there are
// none.
func (a *Access) GetAllIgnoredPackagesByProductID(productID int64) ([]string, error) {
	var names []string
	for _, ignored := range a.IgnoredUpstreamPackages {
		if ignored.ProductID != productID {
			continue
		}
		names = append(names, ignored.Package)
	}

	return names, nil
}
// GetAllRebootSuggestedPackages returns the names of all stored
// reboot-suggested packages. The result is nil when there are none.
func (a *Access) GetAllRebootSuggestedPackages() ([]string, error) {
	var names []string
	for i := range a.RebootSuggestedPackages {
		names = append(names, a.RebootSuggestedPackages[i].Name)
	}

	return names, nil
}
// AddAdvisoryFix records a link between an advisory ID and a fix ID. No
// existence or duplicate checks are performed.
func (a *Access) AddAdvisoryFix(advisoryId int64, fixId int64) error {
	a.AdvisoryFixes = append(a.AdvisoryFixes, &apollodb.AdvisoryFix{
		AdvisoryID: advisoryId,
		FixID:      fixId,
	})

	return nil
}
// AddAdvisoryCVE records a link between an advisory ID and a CVE ID. No
// existence or duplicate checks are performed.
func (a *Access) AddAdvisoryCVE(advisoryId int64, cveId string) error {
	a.AdvisoryCVEs = append(a.AdvisoryCVEs, &apollodb.AdvisoryCVE{
		AdvisoryID: advisoryId,
		CveID:      cveId,
	})

	return nil
}
// AddAdvisoryRPM records an RPM name for the given advisory and product IDs.
// No existence or duplicate checks are performed.
func (a *Access) AddAdvisoryRPM(advisoryId int64, name string, productID int64) error {
	a.AdvisoryRPMs = append(a.AdvisoryRPMs, &apollodb.AdvisoryRPM{
		AdvisoryID: advisoryId,
		Name:       name,
		ProductID:  productID,
	})

	return nil
}
// Begin returns a fresh utils.MockTx; the mock store itself is not affected.
func (a *Access) Begin() (utils.Tx, error) {
	tx := new(utils.MockTx)
	return tx, nil
}
// UseTransaction ignores the supplied transaction and returns the receiver, so
// all operations act on the same in-memory store.
func (a *Access) UseTransaction(_ utils.Tx) apollodb.Access {
	var access apollodb.Access = a
	return access
}