Modules

Copyright 2023 Bell Eapen

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

BaseAgent

Source code in src/dhti_elixir_base/agent.py
class BaseAgent:

    class AgentInput(BaseModel):
        """Chat history with the bot."""
        input: str
        model_config = ConfigDict(extra="ignore", arbitrary_types_allowed=True)

    def __init__(
        self,
        name=None,
        description=None,
        llm=None,
        input_type: type[BaseModel] | None = None,
        prefix=None,
        suffix=None,
        tools: List | None = None,
        mcp=None,
    ):
        self.llm = llm or get_di("function_llm")
        self.prefix = prefix or get_di("prefix")
        self.suffix = suffix or get_di("suffix")
        self.tools = tools or []
        self._name = (
            name or re.sub(r"(?<!^)(?=[A-Z])", "_", self.__class__.__name__).lower()
        )
        self._description = description or f"Agent for {self._name}"
        # current_patient_context = MessagesPlaceholder(variable_name="current_patient_context")
        # memory = ConversationBufferMemory(memory_key="current_patient_context", return_messages=True)
        self.agent_kwargs = {
            "prefix": self.prefix,
            "suffix": self.suffix,
            # "memory_prompts": [current_patient_context],
            "input_variables": ["input", "agent_scratchpad", "current_patient_context"],
        }
        if input_type is None:
            self.input_type = self.AgentInput
        else:
            self.input_type = input_type
        # Always define client so get_langgraph_mcp_agent can safely test it for None
        self.client = MultiServerMCPClient(mcp) if mcp is not None else None

    @property
    def name(self):
        return self._name

    @property
    def description(self):
        return self._description

    @name.setter
    def name(self, value):
        self._name = value

    @description.setter
    def description(self, value):
        self._description = value

    def get_agent(self):
        if self.llm is None:
            raise ValueError("llm must not be None when initializing the agent.")
        return initialize_agent(
            tools=self.tools,
            llm=self.llm,
            agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
            stop=["\nObservation:"],
            max_iterations=len(self.tools) + 3,
            handle_parsing_errors=True,
            agent_kwargs=self.agent_kwargs,
            verbose=True,
        ).with_types(
            input_type=self.input_type # type: ignore
        )

    #! This prompt is intended for models that support llm.bind_tools; see langgraph_agent below.
    def get_agent_prompt(self):
        prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    "{prefix}"
                    " You have access to the following tools: {tool_names}.\n{system_message}",
                ),
                MessagesPlaceholder(variable_name="messages"),
            ]
        )
        prompt = prompt.partial(prefix=self.prefix)
        prompt = prompt.partial(system_message=self.suffix)
        prompt = prompt.partial(
            tool_names=", ".join([tool.name for tool in self.tools])
        )
        return prompt

    def get_agent_chat_prompt_with_memory(self):
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", "You are a helpful assistant."),
                # First put the history
                ("placeholder", "{chat_history}"),
                # Then the new input
                ("human", "{input}"),
                # Finally the scratchpad
                ("placeholder", "{agent_scratchpad}"),
            ]
        )
        return prompt

    def langgraph_agent(self):
        """Create an agent."""
        prompt = self.get_agent_prompt()
        if not hasattr(self.llm, "bind_tools"):
            raise ValueError(
                "The LLM does not support binding tools. Please use a compatible LLM."
            )
        return prompt | self.llm.bind_tools(self.tools)  # type: ignore

    def get_langgraph_agent_executor(self):
        """Get the agent executor."""
        if self.llm is None:
            raise ValueError("llm must not be None when initializing the agent executor.")
        # create_tool_calling_agent requires an "agent_scratchpad" placeholder,
        # which get_agent_prompt() does not provide, so build a chat prompt here.
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", "{prefix}\n{system_message}"),
                ("human", "{input}"),
                ("placeholder", "{agent_scratchpad}"),
            ]
        ).partial(prefix=self.prefix, system_message=self.suffix)
        agent = create_tool_calling_agent(
            llm=self.llm,
            tools=self.tools,
            prompt=prompt,
        )
        agent_executor = AgentExecutor(agent=agent, tools=self.tools)
        return agent_executor

    def get_langgraph_agent_executor_with_memory(self):
        from langchain_core.chat_history import InMemoryChatMessageHistory
        from langchain_core.runnables.history import RunnableWithMessageHistory
        if self.llm is None:
            raise ValueError(
                "llm must not be None when initializing the agent executor."
            )
        memory = InMemoryChatMessageHistory()
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", "You are a helpful assistant."),
                # First put the history
                ("placeholder", "{chat_history}"),
                # Then the new input
                ("human", "{input}"),
                # Finally the scratchpad
                ("placeholder", "{agent_scratchpad}"),
            ]
        )
        agent = create_tool_calling_agent(
            llm=self.llm,
            tools=self.tools,
            prompt=prompt,
        )
        agent_executor = AgentExecutor(agent=agent, tools=self.tools)
        return RunnableWithMessageHistory(
            agent_executor,  # type: ignore
            # A session_id factory is required in real-world deployments; it is
            # ignored here because a single in-memory history is shared.
            lambda session_id: memory,
            input_messages_key="input",
            history_messages_key="chat_history",
        )

    async def get_langgraph_mcp_agent(self):
        """Get the agent executor for async execution."""
        if self.llm is None:
            raise ValueError("llm must not be None when initializing the agent executor.")
        if self.client is None:
            raise ValueError("MCP client must not be None when initializing the agent.")
        tools = await self.client.get_tools()
        agent = create_react_agent(
            model=self.llm,
            tools=tools,
        )
        return agent
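
A minimal construction sketch (ChatOllama and the ping tool are illustrative assumptions, not part of this package):

from langchain_core.tools import tool
from langchain_ollama import ChatOllama  # any chat model that supports tool calling

@tool
def ping(text: str) -> str:
    """Echo the input back; a stand-in for a real clinical tool."""
    return text

agent = BaseAgent(
    name="demo_agent",
    llm=ChatOllama(model="llama3.1"),  # hypothetical local model
    prefix="You are a helpful assistant.",
    suffix="Answer concisely.",
    tools=[ping],
)
executor = agent.get_langgraph_agent_executor()
print(executor.invoke({"input": "Please ping 'hello'."})["output"])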

AgentInput

Bases: BaseModel

Chat history with the bot.

Source code in src/dhti_elixir_base/agent.py
class AgentInput(BaseModel):
    """Chat history with the bot."""
    input: str
    model_config = ConfigDict(extra="ignore", arbitrary_types_allowed=True)
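
A custom schema can be supplied via input_type instead; a sketch (TriageInput and its acuity field are hypothetical):

from pydantic import BaseModel, ConfigDict

class TriageInput(BaseModel):
    """Chat input with an extra routing field."""
    input: str
    acuity: int = 3
    model_config = ConfigDict(extra="ignore")

agent = BaseAgent(llm=some_chat_model, input_type=TriageInput)  # some_chat_model is a placeholder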

get_langgraph_agent_executor()

Get the agent executor.

Source code in src/dhti_elixir_base/agent.py
def get_langgraph_agent_executor(self):
    """Get the agent executor."""
    if self.llm is None:
        raise ValueError("llm must not be None when initializing the agent executor.")
    # create_tool_calling_agent requires an "agent_scratchpad" placeholder,
    # which get_agent_prompt() does not provide, so build a chat prompt here.
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "{prefix}\n{system_message}"),
            ("human", "{input}"),
            ("placeholder", "{agent_scratchpad}"),
        ]
    ).partial(prefix=self.prefix, system_message=self.suffix)
    agent = create_tool_calling_agent(
        llm=self.llm,
        tools=self.tools,
        prompt=prompt,
    )
    agent_executor = AgentExecutor(agent=agent, tools=self.tools)
    return agent_executor
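
The companion get_langgraph_agent_executor_with_memory() wraps the same executor in a RunnableWithMessageHistory; a sketch reusing the agent from the BaseAgent example:

executor = agent.get_langgraph_agent_executor_with_memory()
reply = executor.invoke(
    {"input": "Please ping 'hello'."},
    config={"configurable": {"session_id": "demo"}},  # key is required; the value is unused here
)
print(reply["output"])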

get_langgraph_mcp_agent() async

Get the agent executor for async execution.

Source code in src/dhti_elixir_base/agent.py
async def get_langgraph_mcp_agent(self):
    """Get the agent executor for async execution."""
    if self.llm is None:
        raise ValueError("llm must not be None when initializing the agent executor.")
    if self.client is None:
        raise ValueError("MCP client must not be None when initializing the agent.")
    tools = await self.client.get_tools()
    agent = create_react_agent(
        model=self.llm,
        tools=tools,
    )
    return agent
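
An async usage sketch; the connection dict follows the langchain_mcp_adapters format, and math_server.py is a hypothetical stdio MCP server:

import asyncio

mcp_config = {
    "math": {
        "command": "python",
        "args": ["math_server.py"],
        "transport": "stdio",
    }
}

async def main():
    agent = BaseAgent(llm=some_chat_model, mcp=mcp_config)  # some_chat_model is a placeholder
    graph = await agent.get_langgraph_mcp_agent()
    return await graph.ainvoke({"messages": [("user", "What is 2 + 2?")]})

print(asyncio.run(main()))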

langgraph_agent()

Create an agent.

Source code in src/dhti_elixir_base/agent.py
def langgraph_agent(self):
    """Create an agent."""
    prompt = self.get_agent_prompt()
    if not hasattr(self.llm, "bind_tools"):
        raise ValueError(
            "The LLM does not support binding tools. Please use a compatible LLM."
        )
    return prompt | self.llm.bind_tools(self.tools)  # type: ignore
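
The returned runnable expects a "messages" list, matching the MessagesPlaceholder in get_agent_prompt(); a sketch reusing the agent from the BaseAgent example:

from langchain_core.messages import HumanMessage

chain = agent.langgraph_agent()  # the llm must implement bind_tools
ai_message = chain.invoke({"messages": [HumanMessage(content="Please ping 'hi'.")]})
print(ai_message.tool_calls)  # tool invocations requested by the model, if any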

Copyright 2024 Bell Eapen

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

BaseLLM

Bases: LLM

Source code in src/dhti_elixir_base/llm.py
class BaseLLM(LLM):

    hosted_url: Optional[str] = Field(
        None, alias="hosted_url"
    )  #! Alias is important when inheriting from LLM
    model_name: Optional[str] = Field(None, alias="model_name")
    params: Mapping[str, Any] = Field(default_factory=dict, alias="params")

    backend: Optional[str] = "dhti"
    temperature: Optional[float] = 0.1
    top_p: Optional[float] = 0.8
    top_k: Optional[int] = 40
    n_batch: Optional[int] = 8
    n_threads: Optional[int] = 4
    n_predict: Optional[int] = 256
    max_output_tokens: Optional[int] = 512
    repeat_last_n: Optional[int] = 64
    repeat_penalty: Optional[float] = 1.18

    def __init__(self, hosted_url: str, model_name: str, **kwargs):
        super().__init__(**kwargs)
        self.hosted_url = hosted_url
        self.model_name = model_name
        self.params = {**self._get_model_default_parameters, **kwargs}

    @property
    def _get_model_default_parameters(self):
        return {
            "max_output_tokens": self.max_output_tokens,
            "n_predict": self.n_predict,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "n_batch": self.n_batch,
            "repeat_penalty": self.repeat_penalty,
            "repeat_last_n": self.repeat_last_n,
        }

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """
        Get all the identifying parameters
        """
        return {
            "model_name": self.model_name,
            "hosted_url": self.hosted_url,
            "model_parameters": self._get_model_default_parameters,
        }

    @property
    def _llm_type(self) -> str:
        return "dhti"

    @abstractmethod
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs
    ) -> str:
        """
        Args:
            prompt: The prompt to pass into the model.
            stop: A list of strings to stop generation when encountered
            run_manager: Optional run manager for callbacks and tracing

        Returns:
            The string generated by the model
        """

        pass
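
A concrete subclass only has to implement _call; a minimal sketch assuming the hosted backend exposes a /generate endpoint that returns {"text": ...} (both the endpoint and the payload shape are illustrative, not a documented API):

import requests

class HostedLLM(BaseLLM):
    def _call(self, prompt, stop=None, run_manager=None, **kwargs) -> str:
        response = requests.post(
            f"{self.hosted_url}/generate",  # hypothetical endpoint
            json={"prompt": prompt, "stop": stop, **self.params},
            timeout=60,
        )
        response.raise_for_status()
        return response.json()["text"]  # hypothetical response shape

llm = HostedLLM(hosted_url="http://localhost:8080", model_name="demo-model")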

BaseMCPServer

Bases: FastMCP

Base class for MCP servers, extending FastMCP for custom functionality.

Source code in src/dhti_elixir_base/mcp.py
class BaseMCPServer(FastMCP):
    """Base class for MCP servers, extending FastMCP for custom functionality."""

    def __init__(self, name: str | None = None):
        self._name = name or "BaseMCPServer"
        super().__init__(name=self._name)

    @property
    def name(self):
        """Return the name of this MCP server instance."""
        return self._name

name property

Return the name of this MCP server instance.
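
A usage sketch following FastMCP's tool-registration pattern (the add tool is illustrative):

server = BaseMCPServer(name="demo_mcp")

@server.tool()
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

if __name__ == "__main__":
    server.run(transport="stdio")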

BaseModel

Bases: ABC

A model class to load the model and tokenizer

Source code in src/dhti_elixir_base/model.py
class BaseModel(ABC):
    """A model class to lead the model and tokenizer"""

    model: Any = None

    def __init__(
        self,
        model: Any,
    ) -> None:
        self.model = model

    @classmethod
    @abstractmethod
    def load(cls) -> None:
        if cls.model is None:
            log.info("Loading model")
            t0 = perf_counter()
            # Load the model here
            elapsed = 1000 * (perf_counter() - t0)
            log.info("Model warm-up time: %d ms.", elapsed)
        else:
            log.info("Model is already loaded")

    @classmethod
    @abstractmethod
    def predict(cls, input: Any, **kwargs) -> Any:
        assert input is not None and cls.model is not None  # Sanity check

        # Make sure the model is loaded.
        cls.load()
        t0 = perf_counter()
        # Predict here
        elapsed = 1000 * (perf_counter() - t0)
        log.info("Model prediction time: %d ms.", elapsed)
        return None
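
The load and predict bodies above are templates: subclasses replace the marked steps with real model loading and inference. A sketch with a trivial stand-in model (illustrative only):

from typing import Any

class EchoModel(BaseModel):
    @classmethod
    def load(cls) -> None:
        if cls.model is None:
            cls.model = lambda text: text  # stand-in for real model loading

    @classmethod
    def predict(cls, input: Any, **kwargs) -> Any:
        cls.load()
        return {"text": cls.model(input.text)}  # assumes the input carries a text field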

BaseServer

Bases: ABC

A server class to load the model and tokenizer

Source code in src/dhti_elixir_base/server.py
class BaseServer(ABC):
    """A server class to load the model and tokenizer"""

    class RequestSchema(BaseModel):
        text: str = Field()
        labels: list = Field()
        required: list = Field()

    class ResponseSchema(BaseModel):
        text: str = Field()

    request_schema = RequestSchema
    response_schema = ResponseSchema

    def __init__(
        self, model: BaseModel, request_schema: Any = None, response_schema: Any = None
    ) -> None:
        self.model = model
        if request_schema is not None:
            self.request_schema = request_schema
        if response_schema is not None:
            self.response_schema = response_schema

    @property
    def name(self):
        return re.sub(r"(?<!^)(?=[A-Z])", "_", self.__class__.__name__).lower()

    def health_check(self) -> Any:
        """Health check endpoint"""
        self.model.load()
        return {"status": "ok"}

    def get_schema(self) -> Any:
        """Get the request schema"""
        return self.request_schema

    def predict(self, input: Any, **kwargs) -> Any:
        _input = self.request_schema(**input)  # type: ignore
        _result = self.model.predict(_input, **kwargs)
        result = self.response_schema(**_result)  # type: ignore
        return result
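
BaseServer declares no abstract methods, so it can be exercised directly; a sketch reusing the EchoModel stand-in from the BaseModel section:

model = EchoModel(model=None)  # load() populates the class-level model lazily
server = BaseServer(model=model)
print(server.health_check())  # {'status': 'ok'}
print(server.predict({"text": "hi", "labels": [], "required": []}))  # ResponseSchema with text='hi'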

get_schema()

Get the request schema

Source code in src/dhti_elixir_base/server.py
def get_schema(self) -> Any:
    """Get the request schema"""
    return self.request_schema

health_check()

Health check endpoint

Source code in src/dhti_elixir_base/server.py
def health_check(self) -> Any:
    """Health check endpoint"""
    self.model.load()
    return {"status": "ok"}

BaseSpace

Bases: Agent

Source code in src/dhti_elixir_base/space.py
class BaseSpace(Agent):

    from typing import Optional

    def __init__(self, agent: Optional[BaseAgent] = None, *args, **kwargs):
        if agent is None:
            raise ValueError("agent must not be None when initializing the space.")
        self.agent = agent.get_agent()
        super().__init__(id=agent.name, *args, **kwargs)

    @action
    def say(self, content: str, current_patient_context: str = ""):
        """Search for a patient in the FHIR database."""
        #! TODO: Needs bootstrapping here.

        message = {
            "input": content,
            "current_patient_context": current_patient_context,
        }
        response_content = self.agent.invoke(message)
        self.send(
            {
                "to": self.current_message()["from"], # type: ignore
                "action": {
                    "name": "say",
                    "args": {
                        "content": response_content["output"],
                    },
                },
            }
        )
        return True

say(content, current_patient_context='')

Relay the message to the wrapped agent and reply with its output.

Source code in src/dhti_elixir_base/space.py
@action
def say(self, content: str, current_patient_context: str = ""):
    """Search for a patient in the FHIR database."""
    #! TODO: Needs bootstrapping here.

    message = {
        "input": content,
        "current_patient_context": current_patient_context,
    }
    response_content = self.agent.invoke(message)
    self.send(
        {
            "to": self.current_message()["from"], # type: ignore
            "action": {
                "name": "say",
                "args": {
                    "content": response_content["output"],
                },
            },
        }
    )
    return True
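
A wiring sketch (construction only; hosting the space in an agency deployment is environment-specific, and some_chat_model is a placeholder):

demo = BaseSpace(agent=BaseAgent(llm=some_chat_model, name="demo_agent"))
# Once the surrounding space routes a message to it, the say action forwards the
# content to the wrapped structured-chat agent and sends the output back to the sender.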