Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 50 additions & 15 deletions pkg/blob/blob.go
Original file line number Diff line number Diff line change
Expand Up @@ -590,27 +590,45 @@ func (d *Driver) GetAuthEnv(ctx context.Context, volumeID, protocol string, attr
tenantID = d.cloud.TenantID
}

if clientID != "" {
if mountWithWIToken {
klog.V(2).Infof("clientID(%s) is specified, use workload identity for blobfuse auth", clientID)

workloadIdentityToken, err := parseServiceAccountToken(serviceAccountToken)
if err != nil {
return rgName, accountName, accountKey, containerName, authEnv, err
if mountWithWIToken {
if clientID == "" {
clientID = d.cloud.Config.AzureAuthConfig.UserAssignedIdentityID
if clientID == "" {
return rgName, accountName, accountKey, containerName, authEnv, fmt.Errorf("mountWithWorkloadIdentityToken is true but clientID is not specified")
}
azureOAuthTokenFile := filepath.Join(defaultAzureOAuthTokenDir, clientID+accountName)
}
klog.V(2).Infof("mountWithWorkloadIdentityToken is specified, use workload identity auth for mount, clientID: %s, tenantID: %s", clientID, tenantID)

workloadIdentityToken, err := parseServiceAccountToken(serviceAccountToken)
if err != nil {
return rgName, accountName, accountKey, containerName, authEnv, err
}
tokenFileName := clientID + "-" + accountName
if !isValidTokenFileName(tokenFileName) {
return rgName, accountName, accountKey, containerName, authEnv, fmt.Errorf("the generated token file name %s is invalid", tokenFileName)
}
azureOAuthTokenFile := filepath.Join(defaultAzureOAuthTokenDir, tokenFileName)
// check whether token value is the same as the one in the token file
existingToken, readErr := os.ReadFile(azureOAuthTokenFile)
if readErr == nil && string(existingToken) == workloadIdentityToken {
klog.V(4).Infof("the existing workload identity token file %s is up-to-date, no need to rewrite", azureOAuthTokenFile)
} else {
// write the token to a file
if err := os.WriteFile(azureOAuthTokenFile, []byte(workloadIdentityToken), 0600); err != nil {
return rgName, accountName, accountKey, containerName, authEnv, fmt.Errorf("failed to write workload identity token file %s: %v", azureOAuthTokenFile, err)
}
}

authEnv = append(authEnv, "AZURE_STORAGE_SPN_CLIENT_ID="+clientID)
if tenantID != "" {
authEnv = append(authEnv, "AZURE_STORAGE_SPN_TENANT_ID="+tenantID)
}
authEnv = append(authEnv, "AZURE_OAUTH_TOKEN_FILE="+azureOAuthTokenFile)
klog.V(2).Infof("workload identity auth: %v", authEnv)
return rgName, accountName, accountKey, containerName, authEnv, err
authEnv = append(authEnv, "AZURE_STORAGE_SPN_CLIENT_ID="+clientID)
if tenantID != "" {
authEnv = append(authEnv, "AZURE_STORAGE_SPN_TENANT_ID="+tenantID)
}
authEnv = append(authEnv, "AZURE_OAUTH_TOKEN_FILE="+azureOAuthTokenFile)
klog.V(2).Infof("workload identity auth: %v", authEnv)
return rgName, accountName, accountKey, containerName, authEnv, err
}

if clientID != "" {
klog.V(2).Infof("clientID(%s) is specified, use service account token to get account key", clientID)
if subsID == "" {
subsID = d.cloud.SubscriptionID
Expand Down Expand Up @@ -1244,3 +1262,20 @@ func parseServiceAccountToken(tokenStr string) (string, error) {
}
return token.APIAzureADTokenExchange.Token, nil
}

// isValidTokenFileName checks if the token file name is valid
// fileName should only contain alphanumeric characters, hyphens
func isValidTokenFileName(fileName string) bool {
if fileName == "" {
return false
}
for _, c := range fileName {
if !(('a' <= c && c <= 'z') ||
('A' <= c && c <= 'Z') ||
('0' <= c && c <= '9') ||
(c == '-')) {
return false
}
}
return true
}
62 changes: 62 additions & 0 deletions pkg/blob/blob_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2132,3 +2132,65 @@ func TestIsSupportedPublicNetworkAccess(t *testing.T) {
}
}
}

func TestIsValidTokenFileName(t *testing.T) {
testCases := []struct {
name string
fileName string
expected bool
}{
{
name: "valid lowercase",
fileName: "token",
expected: true,
},
{
name: "valid uppercase",
fileName: "TOKEN",
expected: true,
},
{
name: "valid mixed alphanumeric with hyphen",
fileName: "Token-123",
expected: true,
},
{
name: "valid mixed alphanumeric with hyphen#2",
fileName: "0ab48765-efce-4799-8a9c-c3e1de2ee42eg",
expected: true,
},
{
name: "empty string",
fileName: "",
expected: false,
},
{
name: "contains underscore",
fileName: "token_file",
expected: false,
},
{
name: "contains dot",
fileName: "token.file",
expected: false,
},
{
name: "contains space",
fileName: "token file",
expected: false,
},
{
name: "contains slash",
fileName: "token/file",
expected: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if got := isValidTokenFileName(tc.fileName); got != tc.expected {
t.Fatalf("isValidTokenFileName(%q) = %t, want %t", tc.fileName, got, tc.expected)
}
})
}
}
12 changes: 10 additions & 2 deletions pkg/blob/nodeserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolu
context := req.GetVolumeContext()
if context != nil {
// token request
if context[serviceAccountTokenField] != "" && getValueInMap(context, clientIDField) != "" {
if context[serviceAccountTokenField] != "" && useWorkloadIdentity(context) {
klog.V(2).Infof("NodePublishVolume: volume(%s) mount on %s with service account token, clientID: %s", volumeID, target, getValueInMap(context, clientIDField))
_, err := d.NodeStageVolume(ctx, &csi.NodeStageVolumeRequest{
StagingTargetPath: target,
Expand Down Expand Up @@ -261,7 +261,7 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe
attrib := req.GetVolumeContext()
secrets := req.GetSecrets()

if getValueInMap(attrib, clientIDField) != "" && attrib[serviceAccountTokenField] == "" {
if useWorkloadIdentity(attrib) && attrib[serviceAccountTokenField] == "" {
klog.V(2).Infof("Skip NodeStageVolume for volume(%s) since clientID %s is provided but service account token is empty", volumeID, getValueInMap(attrib, clientIDField))
return &csi.NodeStageVolumeResponse{}, nil
}
Expand Down Expand Up @@ -733,3 +733,11 @@ func checkGidPresentInMountFlags(mountFlags []string) bool {
}
return false
}

// useWorkloadIdentity checks whether workload identity is used based on the presence of clientID or mountWithWIToken in volume attributes
func useWorkloadIdentity(attrib map[string]string) bool {
if getValueInMap(attrib, clientIDField) != "" || getValueInMap(attrib, mountWithWITokenField) == trueValue {
return true
}
return false
}
43 changes: 43 additions & 0 deletions pkg/blob/nodeserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1192,3 +1192,46 @@ func TestCheckGidPresentInMountFlags(t *testing.T) {
}
}
}

func TestUseWorkloadIdentity(t *testing.T) {
tests := []struct {
name string
attrib map[string]string
want bool
}{
{
name: "clientID present",
attrib: map[string]string{
clientIDField: "client-id",
},
want: true,
},
{
name: "mountWithWIToken true",
attrib: map[string]string{
mountWithWITokenField: trueValue,
},
want: true,
},
{
name: "mountWithWIToken false",
attrib: map[string]string{
mountWithWITokenField: "false",
},
want: false,
},
{
name: "no workload identity fields",
attrib: map[string]string{},
want: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := useWorkloadIdentity(tt.attrib); got != tt.want {
t.Errorf("useWorkloadIdentity() = %v, want %v", got, tt.want)
}
})
}
}
Loading