Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/maxdiffusion/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import importlib
import os

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home
from huggingface_hub.constants import HF_HOME, HUGGINGFACE_HUB_CACHE
from packaging import version

from .import_utils import is_peft_available
Expand All @@ -34,7 +34,7 @@
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]

# Below should be `True` if the current version of `peft` and `transformers` are compatible with
Expand Down
48 changes: 15 additions & 33 deletions src/maxdiffusion/utils/dynamic_modules_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,14 @@
import re
import shutil
import sys
import tempfile
from pathlib import Path
from typing import Dict, Optional, Union
from urllib import request

from huggingface_hub import HfFolder, hf_hub_download, model_info
import huggingface_hub
from huggingface_hub import get_token, hf_hub_download, model_info
from packaging import version

cached_download = None

from .. import __version__
from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging

Expand All @@ -42,24 +40,6 @@

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

# https://github.com/huggingface/huggingface_hub/releases/tag/v0.26.0
# `cached_download(), url_to_filename(), filename_to_url() methods are now completely removed.
# From now on, you will have to use hf_hub_download() to benefit from the new cache layout.`
if hasattr(huggingface_hub, "__version__"):
current_version = version.parse(huggingface_hub.__version__)
target_version = version.parse("0.26.0")

if current_version < target_version:
try:
from huggingface_hub import cached_download

except ImportError:
logger.error(
f"huggingface_hub version {current_version} is below 0.26.0, but 'cached_download' could not be imported. It might have been removed or deprecated in this version as well."
)
else:
logger.error("Could not determine huggingface_hub version. Unable to conditionally import 'cached_download'.")


def get_diffusers_versions():
url = "https://pypi.org/pypi/diffusers/json"
Expand Down Expand Up @@ -303,15 +283,17 @@ def get_cached_module_file(
# community pipeline on GitHub
github_url = COMMUNITY_PIPELINES_URL.format(revision=revision, pipeline=pretrained_model_name_or_path)
try:
resolved_module_file = cached_download(
github_url,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=False,
)
# Given that cached download has been removed, try using just urlopen
fd, resolved_module_file = tempfile.mkstemp(dir=cache_dir)
try:
response = request.urlopen(github_url)
with os.fdopen(fd, "wb") as f:
f.write(response.read())
except Exception:
os.remove(resolved_module_file)
raise EnvironmentError(
f"Failed to download community pipeline from {github_url}. Please check if the url is correct."
)
submodule = "git"
module_file = pretrained_model_name_or_path + ".py"
except EnvironmentError:
Expand All @@ -328,7 +310,7 @@ def get_cached_module_file(
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
token=use_auth_token,
)
submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
except EnvironmentError:
Expand Down Expand Up @@ -356,7 +338,7 @@ def get_cached_module_file(
if isinstance(use_auth_token, str):
token = use_auth_token
elif use_auth_token is True:
token = HfFolder.get_token()
token = get_token()
else:
token = None

Expand Down
6 changes: 3 additions & 3 deletions src/maxdiffusion/utils/hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from typing import Dict, Optional, Union
from uuid import uuid4

from huggingface_hub import HfFolder, ModelCard, ModelCardData, create_repo, hf_hub_download, upload_folder, whoami
from huggingface_hub import ModelCard, ModelCardData, create_repo, get_token, hf_hub_download, upload_folder, whoami
from huggingface_hub.file_download import REGEX_COMMIT_HASH
from huggingface_hub.utils import (
EntryNotFoundError,
Expand Down Expand Up @@ -92,7 +92,7 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:

def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
token = get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
Expand Down Expand Up @@ -288,7 +288,7 @@ def _get_model_file(
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision or commit_hash,
Expand Down
Loading