Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/controller/httpapi/v1/profiles.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ func NewProfileRoutes(handler *gin.RouterGroup, t profiles.Feature, l logger.Int
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
_ = v.RegisterValidation("genpasswordwone", dto.ValidateAMTPassOrGenRan)
_ = v.RegisterValidation("ciraortls", dto.ValidateCIRAOrTLS)
_ = v.RegisterValidation("wifidhcp", dto.ValidateWiFiDHCP)
}
}

Expand Down
275 changes: 269 additions & 6 deletions internal/controller/httpapi/v1/profiles_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"reflect"
"sync"
"testing"

"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10"
"github.com/stretchr/testify/require"
gomock "go.uber.org/mock/gomock"

Expand All @@ -18,9 +22,84 @@ import (
"github.com/device-management-toolkit/console/pkg/logger"
)

// defaultValidator implements the gin binding.StructValidator interface
type defaultValidator struct {
once sync.Once
validate *validator.Validate
}

func (v *defaultValidator) ValidateStruct(obj any) error {
if obj == nil {
return nil
}

value := reflect.ValueOf(obj)
switch value.Kind() {
case reflect.Ptr:
if value.IsNil() {
return nil
}

return v.ValidateStruct(value.Elem().Interface())
case reflect.Struct:
return v.validateStruct(obj)
case reflect.Slice, reflect.Array:
count := value.Len()
validateRet := make(binding.SliceValidationError, 0)

for i := 0; i < count; i++ {
if err := v.ValidateStruct(value.Index(i).Interface()); err != nil {
validateRet = append(validateRet, err)
}
}

if len(validateRet) == 0 {
return nil
}

return validateRet
case reflect.Invalid, reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128,
reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.String, reflect.UnsafePointer:
return nil
default:
return nil
}
}

func (v *defaultValidator) validateStruct(obj any) error {
v.lazyinit()

return v.validate.Struct(obj)
}

func (v *defaultValidator) Engine() any {
v.lazyinit()

return v.validate
}

func (v *defaultValidator) lazyinit() {
v.once.Do(func() {
v.validate = validator.New()
v.validate.SetTagName("binding")

// Register custom validators
_ = v.validate.RegisterValidation("genpasswordwone", dto.ValidateAMTPassOrGenRan)
_ = v.validate.RegisterValidation("ciraortls", dto.ValidateCIRAOrTLS)
_ = v.validate.RegisterValidation("wifidhcp", dto.ValidateWiFiDHCP)
})
}

func profilesTest(t *testing.T) (*mocks.MockProfilesFeature, *gin.Engine) {
t.Helper()

// Enable validation for tests
if binding.Validator == nil {
binding.Validator = &defaultValidator{}
}

mockCtl := gomock.NewController(t)
defer mockCtl.Finish()

Expand Down Expand Up @@ -73,9 +152,7 @@ var profileTest = dto.Profile{
UEFIWiFiSyncEnabled: false,
}

func TestProfileRoutes(t *testing.T) { //nolint:gocognit // this is a test function
t.Parallel()

func TestProfileRoutes(t *testing.T) { //nolint:gocognit,paralleltest // this is a test function
tests := []testProfiles{
{
name: "get all profiles",
Expand Down Expand Up @@ -260,12 +337,10 @@ func TestProfileRoutes(t *testing.T) { //nolint:gocognit // this is a test funct
},
}

for _, tc := range tests {
for _, tc := range tests { //nolint:paralleltest // tests run sequentially for simplicity
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

profileFeature, engine := profilesTest(t)

tc.mock(profileFeature)
Expand All @@ -277,6 +352,7 @@ func TestProfileRoutes(t *testing.T) { //nolint:gocognit // this is a test funct
if tc.requestBody.ProfileName != "" {
reqBody, _ := json.Marshal(tc.requestBody)
req, err = http.NewRequestWithContext(context.Background(), tc.method, tc.url, bytes.NewBuffer(reqBody))
req.Header.Set("Content-Type", "application/json")
} else {
req, err = http.NewRequestWithContext(context.Background(), tc.method, tc.url, http.NoBody)
}
Expand Down Expand Up @@ -309,3 +385,190 @@ func TestProfileRoutes(t *testing.T) { //nolint:gocognit // this is a test funct
})
}
}

func TestProfileValidation(t *testing.T) { //nolint:paralleltest // tests run sequentially for simplicity
tests := []struct {
name string
profile dto.Profile
expectedCode int
}{
{
name: "valid profile - CCM with CIRA",
profile: dto.Profile{
ProfileName: "test-profile",
Activation: "ccmactivate",
GenerateRandomPassword: true,
GenerateRandomMEBxPassword: true,
CIRAConfigName: stringPtr("cira-config"),
DHCPEnabled: true,
UserConsent: "All",
TenantID: "tenant1",
},
expectedCode: http.StatusCreated,
},
{
name: "valid profile - ACM with TLS",
profile: dto.Profile{
ProfileName: "test-profile",
Activation: "acmactivate",
GenerateRandomPassword: true,
MEBXPassword: "P@ssw0rd123",
GenerateRandomMEBxPassword: false,
TLSMode: 1,
TLSSigningAuthority: "SelfSigned",
DHCPEnabled: true,
UserConsent: "KVM",
TenantID: "tenant1",
},
expectedCode: http.StatusCreated,
},
{
name: "invalid - both CIRA and TLS",
profile: dto.Profile{
ProfileName: "test-profile",
Activation: "ccmactivate",
GenerateRandomPassword: true,
GenerateRandomMEBxPassword: true,
CIRAConfigName: stringPtr("cira-config"),
TLSMode: 1,
DHCPEnabled: true,
UserConsent: "All",
TenantID: "tenant1",
},
expectedCode: http.StatusBadRequest,
},
{
name: "invalid - wifi configs without DHCP",
profile: dto.Profile{
ProfileName: "test-profile",
Activation: "ccmactivate",
GenerateRandomPassword: true,
GenerateRandomMEBxPassword: true,
DHCPEnabled: false,
WiFiConfigs: []dto.ProfileWiFiConfigs{
{ProfileName: "wifi1", Priority: 1},
},
UserConsent: "All",
TenantID: "tenant1",
},
expectedCode: http.StatusBadRequest,
},
{
name: "invalid - password set with genRandom true",
profile: dto.Profile{
ProfileName: "test-profile",
Activation: "ccmactivate",
AMTPassword: "P@ssw0rd123",
GenerateRandomPassword: true,
GenerateRandomMEBxPassword: true,
DHCPEnabled: true,
UserConsent: "All",
TenantID: "tenant1",
},
expectedCode: http.StatusBadRequest,
},
{
name: "invalid - invalid activation",
profile: dto.Profile{
ProfileName: "test-profile",
Activation: "invalidactivation",
GenerateRandomPassword: true,
GenerateRandomMEBxPassword: true,
UserConsent: "All",
TenantID: "tenant1",
},
expectedCode: http.StatusBadRequest,
},
{
name: "invalid - invalid TLS signing authority",
profile: dto.Profile{
ProfileName: "test-profile",
Activation: "acmactivate",
GenerateRandomPassword: true,
GenerateRandomMEBxPassword: true,
TLSMode: 1,
TLSSigningAuthority: "InvalidAuthority",
DHCPEnabled: true,
UserConsent: "All",
TenantID: "tenant1",
},
expectedCode: http.StatusBadRequest,
},
{
name: "invalid - TLS mode out of range",
profile: dto.Profile{
ProfileName: "test-profile",
Activation: "acmactivate",
GenerateRandomPassword: true,
GenerateRandomMEBxPassword: true,
TLSMode: 5,
TLSSigningAuthority: "SelfSigned",
DHCPEnabled: true,
UserConsent: "All",
TenantID: "tenant1",
},
expectedCode: http.StatusBadRequest,
},
{
name: "invalid - password too short",
profile: dto.Profile{
ProfileName: "test-profile",
Activation: "acmactivate",
AMTPassword: "short",
GenerateRandomPassword: false,
MEBXPassword: "P@ssw0rd123",
GenerateRandomMEBxPassword: false,
DHCPEnabled: true,
UserConsent: "All",
TenantID: "tenant1",
},
expectedCode: http.StatusBadRequest,
},
{
name: "invalid - password missing special character",
profile: dto.Profile{
ProfileName: "test-profile",
Activation: "acmactivate",
AMTPassword: "Password123",
GenerateRandomPassword: false,
MEBXPassword: "P@ssw0rd123",
GenerateRandomMEBxPassword: false,
DHCPEnabled: true,
UserConsent: "All",
TenantID: "tenant1",
},
expectedCode: http.StatusBadRequest,
},
}

for _, tc := range tests { //nolint:paralleltest // tests run sequentially for simplicity
tc := tc

t.Run(tc.name, func(t *testing.T) {
profileFeature, engine := profilesTest(t)

if tc.expectedCode == http.StatusCreated {
profileFeature.EXPECT().Insert(context.Background(), &tc.profile).Return(&tc.profile, nil)
}

reqBody, _ := json.Marshal(tc.profile)
req, err := http.NewRequestWithContext(
context.Background(),
http.MethodPost,
"/api/v1/admin/profiles",
bytes.NewBuffer(reqBody),
)
require.NoError(t, err)
req.Header.Set("Content-Type", "application/json")

w := httptest.NewRecorder()
engine.ServeHTTP(w, req)

require.Equal(t, tc.expectedCode, w.Code)
})
}
}

func stringPtr(s string) *string {
return &s
}
14 changes: 13 additions & 1 deletion internal/entity/dto/v1/profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type Profile struct {
DHCPEnabled bool `json:"dhcpEnabled" example:"true"`
IPSyncEnabled bool `json:"ipSyncEnabled" example:"true"`
LocalWiFiSyncEnabled bool `json:"localWifiSyncEnabled" example:"true"`
WiFiConfigs []ProfileWiFiConfigs `json:"wifiConfigs,omitempty" binding:"excluded_if=DHCPEnabled false,dive"`
WiFiConfigs []ProfileWiFiConfigs `json:"wifiConfigs,omitempty" binding:"wifidhcp,dive"`
TenantID string `json:"tenantId" example:"abc123"`
TLSMode int `json:"tlsMode,omitempty" binding:"omitempty,min=1,max=4,ciraortls" example:"1"`
TLSCerts *TLSCerts `json:"tlsCerts,omitempty"`
Expand Down Expand Up @@ -66,6 +66,18 @@ var ValidateUserConsent validator.Func = func(fl validator.FieldLevel) bool {
return userConsent == "none" || userConsent == "kvm" || userConsent == "all"
}

var ValidateWiFiDHCP validator.Func = func(fl validator.FieldLevel) bool {
dhcpEnabled := fl.Parent().FieldByName("DHCPEnabled").Bool()
wifiConfigs := fl.Field()

// If WiFiConfigs has items and DHCP is disabled, fail validation
if wifiConfigs.Len() > 0 && !dhcpEnabled {
return false
}

return true
}

type ProfileCountResponse struct {
Count int `json:"totalCount"`
Data []Profile `json:"data"`
Expand Down
Loading