-
-
Notifications
You must be signed in to change notification settings - Fork 137
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add list and detail endpoints for known model metadata
Add `metadata` arg to /v2/status/models endpoint to request model metadata for available models Refactor model_reference.py
- Loading branch information
ceruleandeep
committed
Sep 27, 2024
1 parent
aaba202
commit 173ce91
Showing
6 changed files
with
470 additions
and
76 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -113,6 +113,7 @@ venv/ | |
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
.env_docker | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,10 +2,11 @@ | |
# | ||
# SPDX-License-Identifier: AGPL-3.0-or-later | ||
|
||
from flask_restx import fields, reqparse | ||
from flask_restx import Namespace, fields, reqparse | ||
|
||
from horde.enums import WarningMessage | ||
from horde.exceptions import KNOWN_RC | ||
from horde.model_reference import KnownImageModelRef, KnownTextModelRef | ||
from horde.vars import horde_noun, horde_title | ||
|
||
|
||
|
@@ -260,7 +261,7 @@ def __init__(self): | |
|
||
|
||
class Models: | ||
def __init__(self, api): | ||
def __init__(self, api: Namespace): | ||
self.response_model_wp_status_lite = api.model( | ||
"RequestStatusCheck", | ||
{ | ||
|
@@ -406,7 +407,7 @@ def __init__(self, api): | |
min=0, | ||
), | ||
"untrusted": fields.Integer( | ||
description=("How many waiting requests were skipped because they demanded a trusted worker which this worker is not."), | ||
description="How many waiting requests were skipped because they demanded a trusted worker which this worker is not.", | ||
min=0, | ||
), | ||
"models": fields.Integer( | ||
|
@@ -618,11 +619,11 @@ def __init__(self, api): | |
"forms": fields.List(fields.String(description="Which forms this worker if offering.")), | ||
"team": fields.Nested( | ||
self.response_model_team_details_lite, | ||
"The Team to which this worker is dedicated.", | ||
description="The Team to which this worker is dedicated.", | ||
), | ||
"contact": fields.String( | ||
example="[email protected]", | ||
description=("(Privileged) Contact details for the horde admins to reach the owner of this worker in emergencies."), | ||
description="(Privileged) Contact details for the horde admins to reach the owner of this worker in emergencies.", | ||
min_length=5, | ||
max_length=500, | ||
), | ||
|
@@ -1053,7 +1054,7 @@ def __init__(self, api): | |
max=10, | ||
), | ||
"worker_invited": fields.Integer( | ||
description=("Set to the amount of workers this user is allowed to join to the horde when in worker invite-only mode."), | ||
description="Set to the amount of workers this user is allowed to join to the horde when in worker invite-only mode.", | ||
), | ||
"moderator": fields.Boolean( | ||
example=False, | ||
|
@@ -1278,6 +1279,232 @@ def __init__(self, api): | |
), | ||
}, | ||
) | ||
self.response_model_known_model_md = api.model( | ||
"KnownModelMetadata", | ||
{ | ||
"name": fields.String(description="The name of this model."), | ||
"description": fields.String(description="The description of this model."), | ||
"version": fields.String(description="The version of this model."), | ||
"style": fields.String(description="The style of this model."), | ||
"nsfw": fields.Boolean(description="Whether this model can generate NSFW content."), | ||
"baseline": fields.String(description="The baseline model used for this model."), | ||
}, | ||
) | ||
|
||
settings = api.model( | ||
"KnownTextModelSettings", | ||
{ | ||
"n": fields.Integer(example=1, min=1, max=20), | ||
"frmtadsnsp": fields.Boolean( | ||
example=False, | ||
description=( | ||
"Input formatting option. When enabled, adds a leading space to your input " | ||
"if there is no trailing whitespace at the end of the previous action." | ||
), | ||
), | ||
"frmtrmblln": fields.Boolean( | ||
example=False, | ||
description=( | ||
"Output formatting option. When enabled, replaces all occurrences of two or more consecutive newlines " | ||
"in the output with one newline." | ||
), | ||
), | ||
"frmtrmspch": fields.Boolean( | ||
example=False, | ||
description=r"Output formatting option. When enabled, removes #/@%}{+=~|\^<> from the output.", | ||
), | ||
"frmttriminc": fields.Boolean( | ||
example=False, | ||
description=( | ||
"Output formatting option. When enabled, removes some characters from the end of the output such " | ||
"that the output doesn't end in the middle of a sentence. " | ||
"If the output is less than one sentence long, does nothing." | ||
), | ||
), | ||
"max_context_length": fields.Integer( | ||
min=80, | ||
example=1024, | ||
max=32000, | ||
description="Maximum number of tokens to send to the model.", | ||
), | ||
"max_length": fields.Integer( | ||
min=16, | ||
max=1024, | ||
example=80, | ||
description="Number of tokens to generate.", | ||
), | ||
"rep_pen": fields.Float(description="Base repetition penalty value.", min=1, max=3), | ||
"rep_pen_range": fields.Integer(description="Repetition penalty range.", min=0, max=4096), | ||
"rep_pen_slope": fields.Float(description="Repetition penalty slope.", min=0, max=10), | ||
"singleline": fields.Boolean( | ||
example=False, | ||
description=( | ||
"Output formatting option. When enabled, removes everything after the first line of the output, " | ||
"including the newline." | ||
), | ||
), | ||
"temperature": fields.Float(description="Temperature value.", min=0, max=5.0), | ||
"tfs": fields.Float(description="Tail free sampling value.", min=0.0, max=1.0), | ||
"top_a": fields.Float(description="Top-a sampling value.", min=0.0, max=1.0), | ||
"top_k": fields.Integer(description="Top-k sampling value.", min=0, max=100), | ||
"top_p": fields.Float(description="Top-p sampling value.", min=0.001, max=1.0), | ||
"typical": fields.Float(description="Typical sampling value.", min=0.0, max=1.0), | ||
"sampler_order": fields.List( | ||
fields.Integer(description="Array of integers representing the sampler order to be used."), | ||
), | ||
"use_default_badwordsids": fields.Boolean( | ||
example=True, | ||
description="When True, uses the default KoboldAI bad word IDs.", | ||
), | ||
"stop_sequence": fields.List( | ||
fields.String( | ||
description=( | ||
"An array of string sequences whereby the model will stop generating further tokens. " | ||
"The returned text WILL contain the stop sequence." | ||
), | ||
), | ||
), | ||
"min_p": fields.Float(description="Min-p sampling value.", min=0.0, example=0.0, max=1.0), | ||
"smoothing_factor": fields.Float( | ||
description="Quadratic sampling value.", | ||
min=0.0, | ||
example=0.0, | ||
max=10.0, | ||
), | ||
"dynatemp_range": fields.Float( | ||
description="Dynamic temperature range value.", | ||
min=0.0, | ||
example=0.0, | ||
max=5.0, | ||
), | ||
"dynatemp_exponent": fields.Float( | ||
description="Dynamic temperature exponent value.", | ||
min=0.0, | ||
example=1.0, | ||
max=5.0, | ||
), | ||
}, | ||
) | ||
|
||
self.response_model_known_text_model_md = api.inherit( | ||
"KnownTextModelMetadata", | ||
self.response_model_known_model_md, | ||
{ | ||
"parameters": fields.Integer(description="The number of parameters in this model."), | ||
"display_name": fields.String(description="The display name of this model."), | ||
"homepage": fields.String(description="The homepage of the model.", attribute="url"), | ||
"tags": fields.List(fields.String(description="The tags of this model.")), | ||
"instruct_format": fields.String(description="Instruct format template to use for this model."), | ||
"settings": fields.Nested( | ||
settings, | ||
description="Recommended settings for this model.", | ||
allow_null=False, | ||
skip_none=True, | ||
), | ||
}, | ||
) | ||
|
||
requirements = api.model( | ||
"KnownImageModelRequirements", | ||
{ | ||
"clip_skip": fields.Integer( | ||
description="The number of steps to skip in CLIP.", | ||
min=1, | ||
example=1, | ||
), | ||
"min_steps": fields.Integer( | ||
description="The minimum number of steps to take.", | ||
min=1, | ||
example=30, | ||
), | ||
"max_steps": fields.Integer( | ||
description="The maximum number of steps to take.", | ||
min=1, | ||
example=30, | ||
), | ||
"cfg_scale": fields.Float( | ||
description="Classifier-free guidance scale.", | ||
min=0.0, | ||
example=7.5, | ||
), | ||
"min_cfg_scale": fields.Float( | ||
description="Minimum classifier-free guidance scale.", | ||
min=0.0, | ||
example=7.5, | ||
), | ||
"max_cfg_scale": fields.Float( | ||
description="Maximum classifier-free guidance scale.", | ||
min=0.0, | ||
example=7.5, | ||
), | ||
"samplers": fields.List( | ||
fields.String( | ||
description="The samplers to use for this model.", | ||
example="k_euler_a", | ||
), | ||
), | ||
}, | ||
) | ||
|
||
download = api.model( | ||
"KnownImageModelDownload", | ||
{ | ||
"file_name": fields.String(description="The filename of the file to download."), | ||
"file_path": fields.String(description="The path to the file to download."), | ||
"file_url": fields.String(description="The URL to download the file from."), | ||
}, | ||
) | ||
|
||
config = api.model( | ||
"KnownImageModelConfig", | ||
{ | ||
"files": fields.List(fields.Nested(api.model("KnownImageModelFile", {"path": fields.String, "sha256sum": fields.String}))), | ||
"download": fields.List(fields.Nested(download)), | ||
}, | ||
) | ||
self.response_model_known_image_model_md = api.inherit( | ||
"KnownImageModelMetadata", | ||
self.response_model_known_model_md, | ||
{ | ||
"homepage": fields.String(description="The URL of the model's page."), | ||
"weight_type": fields.String(description="Storage format of the model weights.", attribute="type"), | ||
"inpainting": fields.Boolean(description="Whether this model can generate inpainting content."), | ||
"requirements": fields.Nested( | ||
requirements, | ||
description="Generation settings requirements for this model.", | ||
allow_null=False, | ||
skip_none=True, | ||
), | ||
"config": fields.Nested( | ||
config, | ||
description="The configuration of the model.", | ||
allow_null=False, | ||
skip_none=True, | ||
), | ||
"features_not_supported": fields.List(fields.String(description="The features not supported by the model.")), | ||
"size_on_disk_bytes": fields.Integer(description="The size of the model on disk in bytes."), | ||
}, | ||
) | ||
|
||
self.response_model_known_model = api.model( | ||
"KnownModel", | ||
{ | ||
"name": fields.String(description="The name of this model."), | ||
"type": fields.String( | ||
description="Model type (text or image).", | ||
enum=["text", "image"], | ||
), | ||
"metadata": fields.Polymorph( | ||
{ | ||
KnownImageModelRef: self.response_model_known_image_model_md, | ||
KnownTextModelRef: self.response_model_known_text_model_md, | ||
}, | ||
description="The metadata of the model.", | ||
skip_none=True, | ||
), | ||
}, | ||
) | ||
|
||
self.response_model_active_model = api.inherit( | ||
"ActiveModel", | ||
self.response_model_active_model_lite, | ||
|
@@ -1291,8 +1518,17 @@ def __init__(self, api): | |
description="The model type (text or image).", | ||
enum=["image", "text"], | ||
), | ||
"metadata": fields.Polymorph( | ||
{ | ||
KnownImageModelRef: self.response_model_known_image_model_md, | ||
KnownTextModelRef: self.response_model_known_text_model_md, | ||
}, | ||
description="The metadata of the model.", | ||
skip_none=True, | ||
), | ||
}, | ||
) | ||
|
||
self.response_model_deleted_worker = api.model( | ||
"DeletedWorker", | ||
{ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.