Skip to content

Commit 7c152de

Browse files
committed
Introduce include function and supporting plumbing
that was previously implemented in a branch of `replicate/cog`.
1 parent 74b41cc commit 7c152de

File tree

3 files changed

+462
-0
lines changed

3 files changed

+462
-0
lines changed

replicate/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from replicate.client import Client
2+
from replicate.include import include as _include
23
from replicate.pagination import async_paginate as _async_paginate
34
from replicate.pagination import paginate as _paginate
45

@@ -21,3 +22,5 @@
2122
predictions = default_client.predictions
2223
trainings = default_client.trainings
2324
webhooks = default_client.webhooks
25+
26+
include = _include

replicate/include.py

+166
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
import os
2+
import sys
3+
from contextlib import contextmanager
4+
from contextvars import ContextVar
5+
from dataclasses import dataclass
6+
from typing import Any, Callable, Dict, Literal, Optional, Tuple
7+
8+
import replicate
9+
10+
from .exceptions import ModelError
11+
from .model import Model
12+
from .prediction import Prediction
13+
from .run import _has_output_iterator_array_type
14+
from .version import Version
15+
16+
__all__ = ["include"]


# Current execution state for this context: "load", "setup", or "run".
# ``None`` means no state has been entered yet (see ``run_state`` below).
# NOTE: ``Optional[...]`` is used instead of ``X | None`` — the ``|`` union
# is evaluated at runtime inside the ``ContextVar[...]`` subscript and
# requires Python 3.10+, whereas the rest of this file uses the
# ``typing.Optional``/``Dict``/``Tuple`` spellings.
_RUN_STATE: ContextVar[Optional[Literal["load", "setup", "run"]]] = ContextVar(
    "run_state",
    default=None,
)

# API token bound to the current context via the ``run_token`` manager.
_RUN_TOKEN: ContextVar[Optional[str]] = ContextVar("run_token", default=None)
24+
25+
26+
@contextmanager
def run_state(state: Literal["load", "setup", "run"]) -> Any:
    """
    Internal context manager that marks the current execution state.

    Binds ``_RUN_STATE`` to *state* for the duration of the ``with`` block
    and restores the previous value on exit, even if the body raises.
    """
    reset_token = _RUN_STATE.set(state)
    try:
        yield
    finally:
        _RUN_STATE.reset(reset_token)
36+
37+
38+
@contextmanager
def run_token(token: str) -> Any:
    """
    Internal context manager that binds *token* as the API token for the
    current context, restoring the previous binding on exit.
    """
    reset_token = _RUN_TOKEN.set(token)
    try:
        yield
    finally:
        _RUN_TOKEN.reset(reset_token)
48+
49+
50+
def _find_api_token() -> str:
51+
token = os.environ.get("REPLICATE_API_TOKEN")
52+
if token:
53+
print("Using Replicate API token from environment", file=sys.stderr)
54+
return token
55+
56+
token = _RUN_TOKEN.get()
57+
58+
if not token:
59+
raise ValueError("No run token found")
60+
61+
return token
62+
63+
64+
@dataclass
class Run:
    """
    Handle on an in-flight prediction, paired with the model version
    that produced it.
    """

    prediction: Prediction
    version: Version

    def wait(self) -> Any:
        """
        Block until the prediction finishes, then return its output.

        Raises:
            ModelError: if the prediction ended with status "failed".
        """
        prediction = self.prediction
        prediction.wait()

        if prediction.status == "failed":
            raise ModelError(prediction)

        # Join chunked string output into a single string when the version's
        # output schema is an iterator of strings, per the helper's check.
        if _has_output_iterator_array_type(self.version):
            return "".join(prediction.output)

        return prediction.output

    def logs(self) -> Optional[str]:
        """Refresh the prediction from the API and return its logs."""
        self.prediction.reload()
        return self.prediction.logs
94+
95+
96+
@dataclass
class Function:
    """
    A wrapper around a Replicate model that can be called like a function.
    """

    # Model reference in "owner/name" or "owner/name:version" form.
    function_ref: str

    def _client(self) -> replicate.Client:
        """Build a client authenticated with the current context's token."""
        return replicate.Client(api_token=_find_api_token())

    def _split_function_ref(self) -> Tuple[str, str, Optional[str]]:
        """
        Split ``function_ref`` into ``(owner, name, version)``.

        ``version`` is ``None`` when the reference has no ":version" suffix.
        Uses ``str.partition`` so a malformed reference (extra "/" or ":")
        does not raise an opaque tuple-unpacking ``ValueError``.
        """
        owner, _, rest = self.function_ref.partition("/")
        name, _, version = rest.partition(":")
        return owner, name, version or None

    def _model(self) -> Model:
        """Fetch the model record for this reference."""
        client = self._client()
        owner, name, _ = self._split_function_ref()
        return client.models.get(f"{owner}/{name}")

    def _version(self) -> Version:
        """
        Resolve the version to run: the one pinned in the reference if
        present, otherwise the model's latest version.
        """
        client = self._client()
        owner, name, version_id = self._split_function_ref()
        model = client.models.get(f"{owner}/{name}")
        return model.versions.get(version_id) if version_id else model.latest_version

    def __call__(self, **inputs: Dict[str, Any]) -> Any:
        """Start a prediction and block until its output is available."""
        return self.start(**inputs).wait()

    def start(self, **inputs: Dict[str, Any]) -> Run:
        """
        Start a prediction with the specified inputs and return a ``Run``
        handle that can be waited on.
        """
        version = self._version()
        prediction = self._client().predictions.create(version=version, input=inputs)
        print(f"Running {self.function_ref}: https://replicate.com/p/{prediction.id}")

        return Run(prediction, version)

    @property
    def default_example(self) -> Optional[Prediction]:
        """The default example prediction for this model, if any."""
        return self._model().default_example

    @property
    def openapi_schema(self) -> Dict[Any, Any]:
        """The OpenAPI schema for the resolved model version."""
        # ``Dict`` (not ``dict[...]``) for consistency with this module's
        # ``typing`` imports and pre-3.9 annotation evaluation.
        return self._version().openapi_schema
153+
154+
155+
def include(function_ref: str) -> Callable[..., Any]:
    """
    Include a Replicate model as a callable function.

    Args:
        function_ref: Model reference, "owner/name" or "owner/name:version".

    Returns:
        A ``Function`` wrapper that runs the referenced model when called.

    Raises:
        RuntimeError: if called outside the "load" execution state, i.e.
            anywhere other than the top level of the module.
    """
    if _RUN_STATE.get() != "load":
        # The previous message referenced "cog.ext.pipelines.include", a
        # leftover from the branch of replicate/cog this was ported from.
        raise RuntimeError(
            "replicate.include may only be called at the top level."
        )

    return Function(function_ref)

0 commit comments

Comments
 (0)