Basta ya de indagar en otras webs porque has llegado al espacio justo, tenemos la respuesta que necesitas recibir y sin problemas.
Pytorch Hub es un repositorio de modelos previamente entrenado diseñado para facilitar la reproducibilidad de la investigación.
Publicar modelos
Pytorch Hub admite la publicación de modelos previamente entrenados (definiciones de modelo y pesos previamente entrenados) en un repositorio de github agregando un simple hubconf.py
expediente;
hubconf.py
puede tener varios puntos de entrada. Cada punto de entrada se define como una función de Python (ejemplo: un modelo previamente entrenado que desea publicar).
defentrypoint_name(*args,**kwargs):# args & kwargs are optional, for models which take positional/keyword arguments....
¿Cómo implementar un punto de entrada?
Aquí hay un fragmento de código que especifica un punto de entrada para resnet18
modelo si ampliamos la implementación en pytorch/vision/hubconf.py
. En la mayoría de los casos, la importación de la función correcta en hubconf.py
es suficiente. Aquí solo queremos usar la versión expandida como ejemplo para mostrar cómo funciona. Puedes ver el guión completo en pytorch / repositorio de visión
dependencies =['torch']from torchvision.models.resnet import resnet18 as _resnet18 # resnet18 is the name of entrypointdefresnet18(pretrained=False,**kwargs):""" # This docstring shows up in hub.help() Resnet18 model pretrained (bool): kwargs, load pretrained weights into the model """# Call the model, load pretrained weights model = _resnet18(pretrained=pretrained,**kwargs)return model
dependencies
variable es una lista de los nombres de paquetes necesarios para carga el modelo. Tenga en cuenta que esto puede ser ligeramente diferente de las dependencias necesarias para entrenar un modelo.args
ykwargs
se pasan a la función real invocable.- Docstring de la función funciona como un mensaje de ayuda. Explica qué hace el modelo y cuáles son los argumentos posicionales / de palabras clave permitidos. Es muy recomendable agregar algunos ejemplos aquí.
- La función de punto de entrada puede devolver un modelo (nn.module) o herramientas auxiliares para hacer que el flujo de trabajo del usuario sea más fluido, por ejemplo, tokenizadores.
- Los invocables con el prefijo de subrayado se consideran funciones auxiliares que no se mostrarán en
torch.hub.list()
. - Los pesos previamente entrenados se pueden almacenar localmente en el repositorio de github o se pueden cargar mediante
torch.hub.load_state_dict_from_url()
. Si tiene menos de 2 GB, se recomienda conectarlo a un lanzamiento del proyecto y use la URL de la versión. En el ejemplo anteriortorchvision.models.resnet.resnet18
manejaspretrained
, alternativamente, puede poner la siguiente lógica en la definición del punto de entrada.
if pretrained:# For checkpoint saved in local github repo, e.g.=weights/save.pth dirname = os.path.dirname(__file__) checkpoint = os.path.join(dirname,<RELATIVE_PATH_TO_CHECKPOINT>) state_dict = torch.load(checkpoint) model.load_state_dict(state_dict)# For checkpoint saved elsewhere checkpoint ='https://download.pytorch.org/models/resnet18-5c106cde.pth' model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
Noticia importante
- Los modelos publicados deben estar al menos en una rama / etiqueta. No puede ser una confirmación aleatoria.
Carga de modelos desde Hub
Pytorch Hub proporciona API convenientes para explorar todos los modelos disponibles en el hub a través de torch.hub.list()
, muestre cadenas de documentos y ejemplos a través de torch.hub.help()
y cargue los modelos previamente entrenados usando torch.hub.load()
.
torch.hub.list(github, force_reload=False)
[source]-
Enumere todos los puntos de entrada disponibles en
github
hubconf.- Parámetros
-
- github (string) – a string con formato “repo_owner / repo_name[:tag_name]”Con una etiqueta / rama opcional. La rama predeterminada es
master
si no se especifica. Ejemplo: ‘pytorch / vision[:hub]’ - force_reload (bool, Opcional): Si descartar la caché existente y forzar una nueva descarga. El valor predeterminado es
False
.
- github (string) – a string con formato “repo_owner / repo_name[:tag_name]”Con una etiqueta / rama opcional. La rama predeterminada es
- Devoluciones
-
una lista de nombres de puntos de entrada disponibles
- Tipo de retorno
-
puntos de entrada
Ejemplo
>>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
torch.hub.help(github, model, force_reload=False)
[source]-
Mostrar la cadena de documentos del punto de entrada
model
.- Parámetros
-
- github (string) – a string con formato
con una etiqueta / rama opcional. La rama predeterminada es master
si no se especifica. Ejemplo: ‘pytorch / vision[:hub]’ - modelo (string) – a string del nombre del punto de entrada definido en el repositorio hubconf.py
- force_reload (bool, Opcional): Si descartar la caché existente y forzar una nueva descarga. El valor predeterminado es
False
.
- github (string) – a string con formato
Ejemplo
>>>print(torch.hub.help('pytorch/vision','resnet18', force_reload=True))
torch.hub.load(repo_or_dir, model, *args, **kwargs)
[source]-
Cargue un modelo desde un repositorio de github o un directorio local.
Nota: La carga de un modelo es el caso de uso típico, pero también se puede utilizar para cargar otros objetos como tokenizadores, funciones de pérdida, etc.
Si
source
es'github'
,repo_or_dir
se espera que sea de la formarepo_owner/repo_name[:tag_name]
con una etiqueta / rama opcional.Si
source
es'local'
,repo_or_dir
se espera que sea una ruta a un directorio local.- Parámetros
-
- repo_or_dir (string) – nombre del repositorio (
repo_owner/repo_name[:tag_name]
), sisource = 'github'
; o una ruta a un directorio local, sisource = 'local'
. - modelo (string) – el nombre de un invocable (punto de entrada) definido en el repositorio / directorio
hubconf.py
. - * argumentos (Opcional) – los argumentos correspondientes para invocables
model
. - fuente (string, Opcional) –
'github'
|'local'
. Especifica cómorepo_or_dir
debe ser interpretado. El valor predeterminado es'github'
. - force_reload (bool, Opcional): Si se debe forzar una nueva descarga del repositorio de github incondicionalmente. No tiene ningún efecto si
source = 'local'
. El valor predeterminado esFalse
. - verboso (bool, Opcional) – Si
False
, silencia los mensajes sobre el acceso a cachés locales. Tenga en cuenta que el mensaje sobre la primera descarga no se puede silenciar. No tiene ningún efecto sisource = 'local'
. El valor predeterminado esTrue
. - ** kwargs (Opcional) – los kwargs correspondientes para invocables
model
.
- repo_or_dir (string) – nombre del repositorio (
- Devoluciones
-
La salida de la
model
invocable cuando se llama con el dado*args
y**kwargs
.
Ejemplo
>>># from a github repo>>> repo ='pytorch/vision'>>> model = torch.hub.load(repo,'resnet50', pretrained=True)>>># from a local directory>>> path ='/some/local/path/pytorch/vision'>>> model = torch.hub.load(path,'resnet50', pretrained=True)
torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)
[source]-
Descargue el objeto en la URL dada a una ruta local.
- Parámetros
-
- url (string) – URL del objeto para descargar
- dst (string) – Ruta completa donde se guardará el objeto, p. Ej.
/tmp/temporary_file
- hash_prefix (string, Opcional) – Si no es Ninguno, el archivo descargado SHA256 debe comenzar con
hash_prefix
. Predeterminado: Ninguno - Progreso (bool, Opcional) – si mostrar o no una barra de progreso en stderr Predeterminado: Verdadero
Ejemplo
>>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth','/tmp/temporary_file')
torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None)
[source]-
Carga el objeto serializado Torch en la URL dada.
Si el archivo descargado es un archivo zip, se descomprimirá automáticamente.
Si el objeto ya está presente en
model_dir
, se deserializa y se devuelve. El valor predeterminado demodel_dir
es
dónde/checkpoints hub_dir
es el directorio devuelto porget_dir()
.- Parámetros
-
- url (string) – URL del objeto para descargar
- model_dir (string, Opcional) – directorio en el que guardar el objeto
- map_location (Opcional) – una función o un dictado que especifica cómo reasignar ubicaciones de almacenamiento (ver torch.load)
- Progreso (bool, Opcional): Si mostrar o no una barra de progreso en stderr. Predeterminado: Verdadero
- check_hash (bool, Opcional) – Si es True, la parte del nombre de archivo de la URL debe seguir la convención de nomenclatura
filename-
dónde.ext
son los primeros ocho o más dígitos del hash SHA256 del contenido del archivo. El hash se utiliza para garantizar nombres únicos y verificar el contenido del archivo. Predeterminado: falso - nombre del archivo (string, Opcional): Nombre del archivo descargado. Nombre de archivo de
url
se utilizará si no se establece.
Ejemplo
>>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
Ejecutando un modelo cargado:
Tenga en cuenta que *args
y **kwargs
en torch.hub.load()
se utilizan para instanciar un modelo. Después de haber cargado un modelo, ¿cómo puede averiguar qué puede hacer con el modelo? Un flujo de trabajo sugerido es
dir(model)
para ver todos los métodos disponibles del modelo.help(model.foo)
para comprobar que argumentosmodel.foo
toma para correr
Para ayudar a los usuarios a explorar sin consultar la documentación de un lado a otro, recomendamos encarecidamente a los propietarios de repositorios que los mensajes de ayuda de funciones sean claros y concisos. También es útil incluir un ejemplo de trabajo mínimo.
¿Dónde se guardan mis modelos descargados?
Las ubicaciones se utilizan en el orden de
- Vocación
hub.set_dir(
) $TORCH_HOME/hub
, si es variable de entornoTORCH_HOME
Está establecido.$XDG_CACHE_HOME/torch/hub
, si es variable de entornoXDG_CACHE_HOME
Está establecido.~/.cache/torch/hub
torch.hub.get_dir()
[source]-
Obtenga el directorio de caché de Torch Hub utilizado para almacenar modelos y pesos descargados.
Si
set_dir()
no se llama, la ruta predeterminada es$TORCH_HOME/hub
donde variable de entorno$TORCH_HOME
predeterminado a$XDG_CACHE_HOME/torch
.$XDG_CACHE_HOME
sigue la especificación X Design Group del diseño del sistema de archivos Linux, con un valor predeterminado~/.cache
si no se establece la variable de entorno.
torch.hub.set_dir(d)
[source]-
Opcionalmente, configure el directorio de Torch Hub utilizado para guardar modelos y pesos descargados.
- Parámetros
-
D (string): Ruta a una carpeta local para guardar modelos y pesos descargados.
Lógica de almacenamiento en caché
De forma predeterminada, no limpiamos los archivos después de cargarlos. Hub usa la caché de forma predeterminada si ya existe en el directorio devuelto por get_dir()
.
Los usuarios pueden forzar una recarga llamando hub.load(..., force_reload=True)
. Esto eliminará la carpeta github existente y los pesos descargados, reinicializará una nueva descarga. Esto es útil cuando las actualizaciones se publican en la misma sucursal, los usuarios pueden mantenerse al día con la última versión.
Limitaciones conocidas:
Torch hub funciona importando el paquete como si estuviera instalado. Hay algunos efectos secundarios introducidos al importar en Python. Por ejemplo, puede ver nuevos elementos en las cachés de Python sys.modules
y sys.path_importer_cache
que es el comportamiento normal de Python.
Una limitación conocida que vale la pena mencionar aquí es el usuario NO PODER cargar dos ramas diferentes del mismo repositorio en el mismo proceso de Python. Es como instalar dos paquetes con el mismo nombre en Python, lo cual no es bueno. Cache podría unirse a la fiesta y darte sorpresas si realmente lo intentas. Por supuesto, está bien cargarlos en procesos separados.
Al final de todo puedes encontrar las notas de otros programadores, tú de igual manera tienes la habilidad mostrar el tuyo si te apetece.