Pytorch Hub es un repositorio de modelos previamente entrenado diseñado para facilitar la reproducibilidad de la investigación.

Publicar modelos

Pytorch Hub admite la publicación de modelos previamente entrenados (definiciones de modelo y pesos previamente entrenados) en un repositorio de github agregando un simple hubconf.py expediente;

hubconf.py puede tener varios puntos de entrada. Cada punto de entrada se define como una función de Python (ejemplo: un modelo previamente entrenado que desea publicar).

defentrypoint_name(*args,**kwargs):# args & kwargs are optional, for models which take positional/keyword arguments....

¿Cómo implementar un punto de entrada?

Aquí hay un fragmento de código que especifica un punto de entrada para resnet18 modelo si ampliamos la implementación en pytorch/vision/hubconf.py. En la mayoría de los casos, la importación de la función correcta en hubconf.py es suficiente. Aquí solo queremos usar la versión expandida como ejemplo para mostrar cómo funciona. Puedes ver el guión completo en pytorch / repositorio de visión

dependencies =['torch']from torchvision.models.resnet import resnet18 as _resnet18

# resnet18 is the name of entrypointdefresnet18(pretrained=False,**kwargs):""" # This docstring shows up in hub.help()
    Resnet18 model
    pretrained (bool): kwargs, load pretrained weights into the model
    """# Call the model, load pretrained weights
    model = _resnet18(pretrained=pretrained,**kwargs)return model
  • dependencies variable es una lista de los nombres de paquetes necesarios para carga el modelo. Tenga en cuenta que esto puede ser ligeramente diferente de las dependencias necesarias para entrenar un modelo.
  • args y kwargs se pasan a la función real invocable.
  • Docstring de la función funciona como un mensaje de ayuda. Explica qué hace el modelo y cuáles son los argumentos posicionales / de palabras clave permitidos. Es muy recomendable agregar algunos ejemplos aquí.
  • La función de punto de entrada puede devolver un modelo (nn.module) o herramientas auxiliares para hacer que el flujo de trabajo del usuario sea más fluido, por ejemplo, tokenizadores.
  • Los invocables con el prefijo de subrayado se consideran funciones auxiliares que no se mostrarán en torch.hub.list().
  • Los pesos previamente entrenados se pueden almacenar localmente en el repositorio de github o se pueden cargar mediante torch.hub.load_state_dict_from_url(). Si tiene menos de 2 GB, se recomienda conectarlo a un lanzamiento del proyecto y use la URL de la versión. En el ejemplo anterior torchvision.models.resnet.resnet18 manejas pretrained, alternativamente, puede poner la siguiente lógica en la definición del punto de entrada.
if pretrained:# For checkpoint saved in local github repo, e.g. =weights/save.pth
    dirname = os.path.dirname(__file__)
    checkpoint = os.path.join(dirname,<RELATIVE_PATH_TO_CHECKPOINT>)
    state_dict = torch.load(checkpoint)
    model.load_state_dict(state_dict)# For checkpoint saved elsewhere
    checkpoint ='https://download.pytorch.org/models/resnet18-5c106cde.pth'
    model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))

Noticia importante

  • Los modelos publicados deben estar al menos en una rama / etiqueta. No puede ser una confirmación aleatoria.

Carga de modelos desde Hub

Pytorch Hub proporciona API convenientes para explorar todos los modelos disponibles en el hub a través de torch.hub.list(), muestre cadenas de documentos y ejemplos a través de torch.hub.help() y cargue los modelos previamente entrenados usando torch.hub.load().

torch.hub.list(github, force_reload=False)[source]

Enumere todos los puntos de entrada disponibles en github hubconf.

Parámetros
  • github (string) – a string con formato “repo_owner / repo_name[:tag_name]”Con una etiqueta / rama opcional. La rama predeterminada es master si no se especifica. Ejemplo: ‘pytorch / vision[:hub]’
  • force_reload (bool, Opcional): Si descartar la caché existente y forzar una nueva descarga. El valor predeterminado es False.
Devoluciones

una lista de nombres de puntos de entrada disponibles

Tipo de retorno

puntos de entrada

Ejemplo

>>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
torch.hub.help(github, model, force_reload=False)[source]

Mostrar la cadena de documentos del punto de entrada model.

Parámetros
  • github (string) – a string con formato con una etiqueta / rama opcional. La rama predeterminada es master si no se especifica. Ejemplo: ‘pytorch / vision[:hub]’
  • modelo (string) – a string del nombre del punto de entrada definido en el repositorio hubconf.py
  • force_reload (bool, Opcional): Si descartar la caché existente y forzar una nueva descarga. El valor predeterminado es False.

Ejemplo

>>>print(torch.hub.help('pytorch/vision','resnet18', force_reload=True))
torch.hub.load(repo_or_dir, model, *args, **kwargs)[source]

Cargue un modelo desde un repositorio de github o un directorio local.

Nota: La carga de un modelo es el caso de uso típico, pero también se puede utilizar para cargar otros objetos como tokenizadores, funciones de pérdida, etc.

Si source es 'github', repo_or_dir se espera que sea de la forma repo_owner/repo_name[:tag_name] con una etiqueta / rama opcional.

Si source es 'local', repo_or_dir se espera que sea una ruta a un directorio local.

Parámetros
  • repo_or_dir (string) – nombre del repositorio (repo_owner/repo_name[:tag_name]), si source = 'github'; o una ruta a un directorio local, si source = 'local'.
  • modelo (string) – el nombre de un invocable (punto de entrada) definido en el repositorio / directorio hubconf.py.
  • * argumentos (Opcional) – los argumentos correspondientes para invocables model.
  • fuente (string, Opcional) – 'github' | 'local'. Especifica cómo repo_or_dir debe ser interpretado. El valor predeterminado es 'github'.
  • force_reload (bool, Opcional): Si se debe forzar una nueva descarga del repositorio de github incondicionalmente. No tiene ningún efecto si source = 'local'. El valor predeterminado es False.
  • verboso (bool, Opcional) – Si False, silencia los mensajes sobre el acceso a cachés locales. Tenga en cuenta que el mensaje sobre la primera descarga no se puede silenciar. No tiene ningún efecto si source = 'local'. El valor predeterminado es True.
  • ** kwargs (Opcional) – los kwargs correspondientes para invocables model.
Devoluciones

La salida de la model invocable cuando se llama con el dado *args y **kwargs.

Ejemplo

>>># from a github repo>>> repo ='pytorch/vision'>>> model = torch.hub.load(repo,'resnet50', pretrained=True)>>># from a local directory>>> path ='/some/local/path/pytorch/vision'>>> model = torch.hub.load(path,'resnet50', pretrained=True)
torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)[source]

Descargue el objeto en la URL dada a una ruta local.

Parámetros
  • url (string) – URL del objeto para descargar
  • dst (string) – Ruta completa donde se guardará el objeto, p. Ej. /tmp/temporary_file
  • hash_prefix (string, Opcional) – Si no es Ninguno, el archivo descargado SHA256 debe comenzar con hash_prefix. Predeterminado: Ninguno
  • Progreso (bool, Opcional) – si mostrar o no una barra de progreso en stderr Predeterminado: Verdadero

Ejemplo

>>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth','/tmp/temporary_file')
torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None)[source]

Carga el objeto serializado Torch en la URL dada.

Si el archivo descargado es un archivo zip, se descomprimirá automáticamente.

Si el objeto ya está presente en model_dir, se deserializa y se devuelve. El valor predeterminado de model_dir es /checkpoints dónde hub_dir es el directorio devuelto por get_dir().

Parámetros
  • url (string) – URL del objeto para descargar
  • model_dir (string, Opcional) – directorio en el que guardar el objeto
  • map_location (Opcional) – una función o un dictado que especifica cómo reasignar ubicaciones de almacenamiento (ver torch.load)
  • Progreso (bool, Opcional): Si mostrar o no una barra de progreso en stderr. Predeterminado: Verdadero
  • check_hash (bool, Opcional) – Si es True, la parte del nombre de archivo de la URL debe seguir la convención de nomenclatura filename-.ext dónde son los primeros ocho o más dígitos del hash SHA256 del contenido del archivo. El hash se utiliza para garantizar nombres únicos y verificar el contenido del archivo. Predeterminado: falso
  • nombre del archivo (string, Opcional): Nombre del archivo descargado. Nombre de archivo de url se utilizará si no se establece.

Ejemplo

>>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

Ejecutando un modelo cargado:

Tenga en cuenta que *args y **kwargs en torch.hub.load() se utilizan para instanciar un modelo. Después de haber cargado un modelo, ¿cómo puede averiguar qué puede hacer con el modelo? Un flujo de trabajo sugerido es

  • dir(model) para ver todos los métodos disponibles del modelo.
  • help(model.foo) para comprobar que argumentos model.foo toma para correr

Para ayudar a los usuarios a explorar sin consultar la documentación de un lado a otro, recomendamos encarecidamente a los propietarios de repositorios que los mensajes de ayuda de funciones sean claros y concisos. También es útil incluir un ejemplo de trabajo mínimo.

¿Dónde se guardan mis modelos descargados?

Las ubicaciones se utilizan en el orden de

  • Vocación hub.set_dir()
  • $TORCH_HOME/hub, si es variable de entorno TORCH_HOME Está establecido.
  • $XDG_CACHE_HOME/torch/hub, si es variable de entorno XDG_CACHE_HOME Está establecido.
  • ~/.cache/torch/hub
torch.hub.get_dir()[source]

Obtenga el directorio de caché de Torch Hub utilizado para almacenar modelos y pesos descargados.

Si set_dir() no se llama, la ruta predeterminada es $TORCH_HOME/hub donde variable de entorno $TORCH_HOME predeterminado a $XDG_CACHE_HOME/torch. $XDG_CACHE_HOME sigue la especificación X Design Group del diseño del sistema de archivos Linux, con un valor predeterminado ~/.cache si no se establece la variable de entorno.

torch.hub.set_dir(d)[source]

Opcionalmente, configure el directorio de Torch Hub utilizado para guardar modelos y pesos descargados.

Parámetros

D (string): Ruta a una carpeta local para guardar modelos y pesos descargados.

Lógica de almacenamiento en caché

De forma predeterminada, no limpiamos los archivos después de cargarlos. Hub usa la caché de forma predeterminada si ya existe en el directorio devuelto por get_dir().

Los usuarios pueden forzar una recarga llamando hub.load(..., force_reload=True). Esto eliminará la carpeta github existente y los pesos descargados, reinicializará una nueva descarga. Esto es útil cuando las actualizaciones se publican en la misma sucursal, los usuarios pueden mantenerse al día con la última versión.

Limitaciones conocidas:

Torch hub funciona importando el paquete como si estuviera instalado. Hay algunos efectos secundarios introducidos al importar en Python. Por ejemplo, puede ver nuevos elementos en las cachés de Python sys.modules y sys.path_importer_cache que es el comportamiento normal de Python.

Una limitación conocida que vale la pena mencionar aquí es el usuario NO PODER cargar dos ramas diferentes del mismo repositorio en el mismo proceso de Python. Es como instalar dos paquetes con el mismo nombre en Python, lo cual no es bueno. Cache podría unirse a la fiesta y darte sorpresas si realmente lo intentas. Por supuesto, está bien cargarlos en procesos separados.