Prefer reading prompt from Valohai parameters.json
ruksi committed Nov 19, 2024
1 parent 28114fc commit ce2878e
Showing 1 changed file with 6 additions and 2 deletions.
inference-mistral.py (8 changes: 6 additions & 2 deletions)
@@ -51,12 +51,16 @@ def decode(self, model_outputs):
 
 
 def run(args):
+    prompt = valohai.parameters('prompt', args.prompt).value
+    if not prompt:
+        raise ValueError('--prompt argument is required when running outside of Valohai')
+
     inference = ModelInference(
         base_mistral_model=args.base_mistral_model,
         checkpoint_path=args.checkpoint_path,
     )
     response = inference.generate_response(
-        prompt=args.prompt,
+        prompt=prompt,
         max_tokens=args.max_tokens,
     )
     print('Generated Response:')

@@ -71,7 +75,7 @@ def main():
     parser.add_argument('--base_mistral_model', type=str, default='mistralai/Mistral-7B-v0.1', help='Mistral model path or id from Hugging Face')
     parser.add_argument('--checkpoint_path', type=str)
     parser.add_argument('--max_tokens', type=int, default=305, help='Maximum number of tokens in response')
-    parser.add_argument('--prompt', type=str, required=True, help='Input prompt for text generation')
+    parser.add_argument('--prompt', type=str, help='Input prompt for text generation')
     # fmt: on
     args = parser.parse_args()
 
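For context, a minimal, self-contained sketch of the fallback behavior this commit relies on. It reuses only the calls visible in the diff above (valohai.parameters with a default value and the now-optional --prompt flag); the surrounding main() scaffolding is illustrative and not part of the commit:

import argparse

import valohai


def main():
    parser = argparse.ArgumentParser()
    # --prompt is no longer required; a Valohai parameter can supply it instead
    parser.add_argument('--prompt', type=str, help='Input prompt for text generation')
    args = parser.parse_args()

    # On Valohai, this resolves to the 'prompt' value from the execution's
    # parameters.json; outside Valohai, valohai-utils falls back to the
    # default passed in (args.prompt).
    prompt = valohai.parameters('prompt', args.prompt).value
    if not prompt:
        raise ValueError('--prompt argument is required when running outside of Valohai')
    print(prompt)


if __name__ == '__main__':
    main()

Inside a Valohai execution the prompt can therefore come from the execution's parameters, so --prompt may be omitted; in a plain local run the call falls back to the CLI argument, and if neither source provides a value the ValueError above is raised.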
