Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 248 additions & 38 deletions cli/azd/extensions/azure.ai.agents/internal/cmd/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
Expand Down Expand Up @@ -875,6 +877,7 @@ func (a *InitAction) downloadAgentYaml(
var urlInfo *GitHubUrlInfo
var ghCli *github.Cli
var console input.Console
var useGhCli bool = false

// Check if manifestPointer is a local file path or a URI
if a.isLocalFilePath(manifestPointer) {
Expand Down Expand Up @@ -927,6 +930,15 @@ func (a *InitAction) downloadAgentYaml(
}
} else if a.isGitHubUrl(manifestPointer) {
// Handle GitHub URLs using downloadGithubManifest
// manifestPointer validation:
// - accepts only URLs with the following format:
// - https://raw.<hostname>/<owner>/<repo>/refs/heads/<branch>/<path>/<file>.json
// - This url comes from a user clicking the `raw` button on a file in a GitHub repository (web view).
// - https://<hostname>/<owner>/<repo>/blob/<branch>/<path>/<file>.json
// - This url comes from a user browsing GitHub repository and copy-pasting the url from the browser.
// - https://api.<hostname>/repos/<owner>/<repo>/contents/<path>/<file>.json
// - This url comes from users familiar with the GitHub API. Usually for programmatic registration of templates.

fmt.Printf("Downloading agent.yaml from GitHub: %s\n", manifestPointer)
isGitHubUrl = true

Expand Down Expand Up @@ -954,20 +966,48 @@ func (a *InitAction) downloadAgentYaml(
return nil, "", fmt.Errorf("ensuring gh is installed: %w", err)
}

urlInfo, err = a.parseGitHubUrl(ctx, manifestPointer)
if err != nil {
return nil, "", err
var contentStr string
// First try naive parsing assuming branch is a single word. This allows users to not have to authenticate
// with gh CLI for public repositories.
urlInfo = a.parseGitHubUrlNaive(manifestPointer)
if urlInfo != nil {
// Construct raw GitHub URL to fetch file directly
rawUrl := fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s", urlInfo.RepoSlug, urlInfo.Branch, urlInfo.FilePath)
fmt.Printf("Attempting to download manifest from: %s\n", rawUrl)

resp, err := http.Get(rawUrl)
if err == nil && resp.StatusCode == http.StatusOK {
defer resp.Body.Close()
bodyBytes, readErr := io.ReadAll(resp.Body)
if readErr == nil {
contentStr = string(bodyBytes)
fmt.Printf("Downloaded manifest from branch: %s\n", urlInfo.Branch)
}
}
if contentStr == "" {
fmt.Printf("Warning: naive GitHub URL parsing failed to download manifest\n")
fmt.Println("Proceeding with full parsing and download logic...")
}
}

apiPath := fmt.Sprintf("/repos/%s/contents/%s", urlInfo.RepoSlug, urlInfo.FilePath)
if urlInfo.Branch != "" {
fmt.Printf("Downloaded manifest from branch: %s\n", urlInfo.Branch)
apiPath += fmt.Sprintf("?ref=%s", urlInfo.Branch)
}
if contentStr == "" {
// Fall back to complex parsing via azd GitHub CLI handling
useGhCli = true
urlInfo, err = a.parseGitHubUrl(ctx, manifestPointer)
if err != nil {
return nil, "", err
}

contentStr, err := downloadGithubManifest(ctx, urlInfo, apiPath, ghCli, console)
if err != nil {
return nil, "", fmt.Errorf("downloading from GitHub: %w", err)
apiPath := fmt.Sprintf("/repos/%s/contents/%s", urlInfo.RepoSlug, urlInfo.FilePath)
if urlInfo.Branch != "" {
fmt.Printf("Downloaded manifest from branch: %s\n", urlInfo.Branch)
apiPath += fmt.Sprintf("?ref=%s", urlInfo.Branch)
}

contentStr, err = downloadGithubManifest(ctx, urlInfo, apiPath, ghCli)
if err != nil {
return nil, "", fmt.Errorf("downloading from GitHub: %w", err)
}
}

content = []byte(contentStr)
Expand Down Expand Up @@ -1116,7 +1156,7 @@ func (a *InitAction) downloadAgentYaml(
if isHostedContainer {
// For container agents, download the entire parent directory
fmt.Println("Downloading full directory for container agent")
err := downloadParentDirectory(ctx, urlInfo, targetDir, ghCli, console)
err := downloadParentDirectory(ctx, urlInfo, targetDir, ghCli, console, useGhCli)
if err != nil {
return nil, "", fmt.Errorf("downloading parent directory: %w", err)
}
Expand Down Expand Up @@ -1341,29 +1381,9 @@ func (a *InitAction) populateContainerSettings(ctx context.Context) (*project.Co
}

func downloadGithubManifest(
ctx context.Context, urlInfo *GitHubUrlInfo, apiPath string, ghCli *github.Cli, console input.Console) (string, error) {
// manifestPointer validation:
// - accepts only URLs with the following format:
// - https://raw.<hostname>/<owner>/<repo>/refs/heads/<branch>/<path>/<file>.json
// - This url comes from a user clicking the `raw` button on a file in a GitHub repository (web view).
// - https://<hostname>/<owner>/<repo>/blob/<branch>/<path>/<file>.json
// - This url comes from a user browsing GitHub repository and copy-pasting the url from the browser.
// - https://api.<hostname>/repos/<owner>/<repo>/contents/<path>/<file>.json
// - This url comes from users familiar with the GitHub API. Usually for programmatic registration of templates.

authResult, err := ghCli.GetAuthStatus(ctx, urlInfo.Hostname)
if err != nil {
return "", fmt.Errorf("failed to get auth status: %w", err)
}
if !authResult.LoggedIn {
// ensure no spinner is shown when logging in, as this is interactive operation
console.StopSpinner(ctx, "", input.Step)
err := ghCli.Login(ctx, urlInfo.Hostname)
if err != nil {
return "", fmt.Errorf("failed to login: %w", err)
}
console.ShowSpinner(ctx, "Validating template source", input.Step)
}
ctx context.Context, urlInfo *GitHubUrlInfo, apiPath string, ghCli *github.Cli) (string, error) {
// This method assumes that either the repo is public, or the user has already been prompted to log in to the github cli
// through our use of the underlying azd logic.

content, err := ghCli.ApiCall(ctx, urlInfo.Hostname, apiPath, github.ApiCallOptions{
Headers: []string{"Accept: application/vnd.github.v3.raw"},
Expand All @@ -1375,6 +1395,99 @@ func downloadGithubManifest(
return content, nil
}

// parseGitHubUrlNaive attempts to parse a GitHub URL without contacting GitHub, by assuming
// the branch name is a single path segment (i.e. it contains no slashes).
//
// Supported formats:
//   - https://github.com/{owner}/{repo}/blob/{branch}/{path}
//   - https://raw.githubusercontent.com/{owner}/{repo}/refs/heads/{branch}/{path}
//
// It returns nil when the URL cannot be parsed, when the host is not one of the two above,
// when the path does not have the expected segments in the expected positions, or when any of
// owner, repo, branch or file path is empty. Query parameters and fragments are ignored.
//
// Because a branch such as "feature/foo" is indistinguishable from branch "feature" followed
// by directory "foo", the result is only a best guess; callers must be prepared for the guess
// to be wrong and fall back to a slower, authoritative parse.
func (a *InitAction) parseGitHubUrlNaive(manifestPointer string) *GitHubUrlInfo {
	// url.Parse separates the path from any query string or fragment.
	parsedURL, err := url.Parse(manifestPointer)
	if err != nil {
		return nil
	}

	trimmedPath := strings.TrimPrefix(parsedURL.Path, "/")

	var owner, repo, branch, filePath string
	switch parsedURL.Host {
	case "github.com":
		// {owner}/{repo}/blob/{branch}/{path...}
		// Matching "blob" by position (rather than searching for the first "/blob/") keeps
		// repositories that are themselves named "blob" from being mis-parsed.
		segments := strings.SplitN(trimmedPath, "/", 5)
		if len(segments) != 5 || segments[2] != "blob" {
			return nil
		}
		owner, repo, branch, filePath = segments[0], segments[1], segments[3], segments[4]
	case "raw.githubusercontent.com":
		// {owner}/{repo}/refs/heads/{branch}/{path...}
		segments := strings.SplitN(trimmedPath, "/", 6)
		if len(segments) != 6 || segments[2] != "refs" || segments[3] != "heads" {
			return nil
		}
		owner, repo, branch, filePath = segments[0], segments[1], segments[4], segments[5]
	default:
		return nil
	}

	// Every component is required to build a usable download URL.
	if owner == "" || repo == "" || branch == "" || filePath == "" {
		return nil
	}

	return &GitHubUrlInfo{
		RepoSlug: owner + "/" + repo,
		Branch:   branch,
		FilePath: filePath,
		// API calls use github.com even when the manifest URL pointed at the raw host.
		Hostname: "github.com",
	}
}

// parseGitHubUrl extracts repository information from various GitHub URL formats using extension framework
func (a *InitAction) parseGitHubUrl(ctx context.Context, manifestPointer string) (*GitHubUrlInfo, error) {
urlInfo, err := a.azdClient.Project().ParseGitHubUrl(ctx, &azdext.ParseGitHubUrlRequest{
Expand All @@ -1393,7 +1506,7 @@ func (a *InitAction) parseGitHubUrl(ctx context.Context, manifestPointer string)
}

func downloadParentDirectory(
ctx context.Context, urlInfo *GitHubUrlInfo, targetDir string, ghCli *github.Cli, console input.Console) error {
ctx context.Context, urlInfo *GitHubUrlInfo, targetDir string, ghCli *github.Cli, console input.Console, useGhCli bool) error {

// Get parent directory by removing the filename from the file path
pathParts := strings.Split(urlInfo.FilePath, "/")
Expand All @@ -1406,8 +1519,14 @@ func downloadParentDirectory(
fmt.Printf("Downloading parent directory '%s' from repository '%s', branch '%s'\n", parentDirPath, urlInfo.RepoSlug, urlInfo.Branch)

// Download directory contents
if err := downloadDirectoryContents(ctx, urlInfo.Hostname, urlInfo.RepoSlug, parentDirPath, urlInfo.Branch, targetDir, ghCli, console); err != nil {
return fmt.Errorf("failed to download directory contents: %w", err)
if useGhCli {
if err := downloadDirectoryContents(ctx, urlInfo.Hostname, urlInfo.RepoSlug, parentDirPath, urlInfo.Branch, targetDir, ghCli, console); err != nil {
return fmt.Errorf("failed to download directory contents: %w", err)
}
} else {
if err := downloadDirectoryContentsWithoutGhCli(ctx, urlInfo.RepoSlug, parentDirPath, urlInfo.Branch, targetDir); err != nil {
return fmt.Errorf("failed to download directory contents: %w", err)
}
}

fmt.Printf("Successfully downloaded parent directory to: %s\n", targetDir)
Expand Down Expand Up @@ -1484,6 +1603,97 @@ func downloadDirectoryContents(
return nil
}

// downloadDirectoryContentsWithoutGhCli recursively copies the directory dirPath of the GitHub
// repository repoSlug ("owner/repo") at ref branch into the local directory localPath, using
// plain unauthenticated HTTP requests instead of the gh CLI.
//
// Directory listings come from the api.github.com "contents" endpoint; file bodies come from
// raw.githubusercontent.com. Sub-directories are created locally and downloaded recursively.
// localPath itself is not created here and is expected to exist already. Entries whose type is
// neither "file" nor "dir" are ignored.
//
// All requests are bound to ctx, so cancelling ctx aborts the download. The first failure is
// returned; files written before the failure are left in place.
func downloadDirectoryContentsWithoutGhCli(
	ctx context.Context, repoSlug string, dirPath string, branch string, localPath string) error {

	dirContents, err := listGitHubDirectoryAnonymously(ctx, repoSlug, dirPath, branch)
	if err != nil {
		return err
	}

	for _, item := range dirContents {
		name, ok := item["name"].(string)
		if !ok {
			continue
		}

		itemType, ok := item["type"].(string)
		if !ok {
			continue
		}

		// The name comes from a remote response and is joined into a local path below, so it
		// must be a single path element that cannot climb out of localPath.
		if name == "." || name == ".." || name != filepath.Base(name) {
			return fmt.Errorf("unexpected entry name %q in directory %q", name, dirPath)
		}

		// Avoid a leading slash (and a resulting "//" in URLs) when listing the repository root.
		itemPath := name
		if dirPath != "" {
			itemPath = dirPath + "/" + name
		}
		itemLocalPath := filepath.Join(localPath, name)

		switch itemType {
		case "file":
			fmt.Printf("Downloading file: %s\n", itemPath)
			if err := downloadRawGitHubFile(ctx, repoSlug, branch, itemPath, itemLocalPath); err != nil {
				return err
			}
		case "dir":
			fmt.Printf("Downloading directory: %s\n", itemPath)
			if err := os.MkdirAll(itemLocalPath, 0755); err != nil {
				return fmt.Errorf("failed to create directory %s: %w", itemLocalPath, err)
			}

			if err := downloadDirectoryContentsWithoutGhCli(
				ctx, repoSlug, itemPath, branch, itemLocalPath); err != nil {
				return fmt.Errorf("failed to download subdirectory %s: %w", itemPath, err)
			}
		}
	}

	return nil
}

// listGitHubDirectoryAnonymously fetches and decodes the GitHub contents listing for dirPath.
// The response body is fully consumed and closed before returning, so no connection is held
// open while the caller downloads the listed entries.
func listGitHubDirectoryAnonymously(
	ctx context.Context, repoSlug string, dirPath string, branch string) ([]map[string]interface{}, error) {

	// Assemble the URL from parts so that the path and the ref are escaped correctly.
	apiURL := url.URL{
		Scheme: "https",
		Host:   "api.github.com",
		Path:   fmt.Sprintf("/repos/%s/contents/%s", repoSlug, dirPath),
	}
	if branch != "" {
		apiURL.RawQuery = url.Values{"ref": []string{branch}}.Encode()
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL.String(), nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Accept", "application/vnd.github.v3+json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to get directory contents: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("failed to get directory contents: status %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read directory contents response: %w", err)
	}

	var dirContents []map[string]interface{}
	if err := json.Unmarshal(body, &dirContents); err != nil {
		return nil, fmt.Errorf("failed to parse directory contents JSON: %w", err)
	}

	return dirContents, nil
}

// downloadRawGitHubFile downloads a single file from raw.githubusercontent.com and writes it to
// localPath. It lives in its own function so the response body is closed as soon as each file
// is done rather than when the whole directory walk finishes.
func downloadRawGitHubFile(
	ctx context.Context, repoSlug string, branch string, filePath string, localPath string) error {

	// url.URL escapes characters such as '#', '?' and spaces that may appear in file names.
	rawURL := url.URL{
		Scheme: "https",
		Host:   "raw.githubusercontent.com",
		Path:   fmt.Sprintf("/%s/%s/%s", repoSlug, branch, filePath),
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL.String(), nil)
	if err != nil {
		return fmt.Errorf("failed to download file %s: %w", filePath, err)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to download file %s: %w", filePath, err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("failed to download file %s: status %d", filePath, resp.StatusCode)
	}

	// Read fully before writing so a failed transfer never leaves a truncated file behind.
	fileContent, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("failed to read file content %s: %w", filePath, err)
	}

	if err := os.WriteFile(localPath, fileContent, 0644); err != nil {
		return fmt.Errorf("failed to write file %s: %w", localPath, err)
	}

	return nil
}

// func (a *InitAction) validateResources(ctx context.Context, agentYaml map[string]interface{}) error {
// fmt.Println("Reading model name from agent.yaml...")

Expand Down
Loading
Loading