Learn how to integrate Guardrails with any custom LLM API that isn’t natively supported.

Overview

If you’re using an LLM that isn’t natively supported by Guardrails and you don’t want to use LiteLLM, you can build a custom LLM API wrapper. This gives you full control over how your LLM integrates with Guardrails.

Basic custom wrapper

Create a function that accepts keyword arguments (at minimum messages, a list of message dictionaries) and returns the LLM output as a string:
from guardrails import Guard
from guardrails.hub import ProfanityFree

# Create a Guard class
guard = Guard().use(ProfanityFree())

# Function that takes messages (via keyword arguments) and returns the LLM output as a string
def my_llm_api(**kwargs) -> str:
    """Custom LLM API wrapper.
    
    At minimum, messages should be provided.
    
    Args:
        messages (list[dict]): The message history to be passed to the LLM API
        **kwargs: Any additional arguments to be passed to the LLM API
    
    Returns:
        str: The output of the LLM API
    """
    messages = kwargs.get("messages")
    
    # Call your LLM API here
    llm_output = some_llm(messages, **kwargs)
    
    return llm_output

# Wrap your LLM API call
validated_response = guard(
    my_llm_api,
    messages=[{"role": "user", "content": "Tell me a story"}],
    # any additional keyword arguments are forwarded to my_llm_api
)
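
The call returns a ValidationOutcome object, the same object the later examples inspect; for example:
# Check whether validation passed and read the validated text
print(validated_response.validation_passed)
print(validated_response.validated_output)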

Example: Custom REST API

Here’s an example of wrapping a custom REST API:
import requests
from guardrails import Guard

def custom_llm_api(**kwargs) -> str:
    """Wrapper for a custom LLM REST API."""
    messages = kwargs.get("messages")
    temperature = kwargs.get("temperature", 0.7)
    max_tokens = kwargs.get("max_tokens", 100)
    
    # Make a request to your custom API
    response = requests.post(
        "https://your-llm-api.com/generate",
        json={
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens
        },
        headers={"Authorization": f"Bearer {YOUR_API_KEY}"},
        timeout=30,
    )
    response.raise_for_status()
    
    return response.json()["output"]

# Use with Guardrails
guard = Guard()

result = guard(
    custom_llm_api,
    messages=[{"role": "user", "content": "Hello!"}],
    temperature=0.8,
    max_tokens=150
)

print(result.validated_output)

Example: Local model

Here’s an example of wrapping a locally-hosted model:
from transformers import pipeline
from guardrails import Guard
from guardrails.validators import ValidLength, ToxicLanguage
from guardrails import OnFailAction

# Setup pipeline
generator = pipeline("text-generation", model="facebook/opt-350m")

# Create your prompt or starting text
prompt = "What are we having for dinner?"

# Create the Guard
guard = Guard.for_string(
    validators=[
        ValidLength(
            min=48,
            on_fail=OnFailAction.FIX
        ),
        ToxicLanguage(
            on_fail=OnFailAction.FIX
        )
    ],
    prompt=prompt
)

# Run the Guard
response = guard(
    llm_api=generator,
    max_new_tokens=40
)

if response.validation_passed:
    print("validated_output: ", response.validated_output)
else:
    print("error: ", response.error)

Example: Hugging Face models

For Hugging Face text generation models:
from guardrails import Guard
from guardrails.validators import ValidLength, ToxicLanguage
from guardrails import OnFailAction
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Create your prompt or starting text
prompt = "Hello, I'm a language model,"

# Setup torch
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# Instantiate your tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Instantiate your model
model = AutoModelForCausalLM.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id).to(torch_device)

# Customize your model inputs if desired
model_inputs = tokenizer(prompt, return_tensors="pt").to(torch_device)

# Create the Guard
guard = Guard.for_string(
    validators=[
        ValidLength(
            min=48,
            on_fail=OnFailAction.FIX
        ),
        ToxicLanguage(
            on_fail=OnFailAction.FIX
        )
    ],
    prompt=prompt
)

# Run the Guard
response = guard(
    llm_api=model.generate,
    max_new_tokens=40,
    tokenizer=tokenizer,
    **model_inputs,
)

# Check the output
if response.validation_passed:
    print("validated_output: ", response.validated_output)
else:
    print("error: ", response.error)

Requirements for custom wrappers

Your custom LLM wrapper should do all of the following (a minimal skeleton is sketched after this list):
  1. Accept messages as a parameter (list of message dictionaries)
  2. Return a string containing the LLM’s output
  3. Accept additional keyword arguments that you want to pass to your LLM
  4. Handle errors appropriately
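
Putting these requirements together, a minimal skeleton might look like the following. This is a sketch only; call_my_llm is a hypothetical stand-in for your actual client.
def my_llm_api(**kwargs) -> str:
    """Minimal wrapper satisfying the requirements above."""
    messages = kwargs.get("messages")
    if not messages:
        raise ValueError("messages is required")
    
    try:
        # call_my_llm is a placeholder for your real LLM client
        return call_my_llm(messages, **kwargs)
    except Exception as exc:
        # Fail loudly so Guardrails doesn't validate partial output
        raise RuntimeError(f"LLM call failed: {exc}") from exc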

Using with validators

Custom LLM wrappers work with any Guardrails validator:
from guardrails import Guard
from guardrails.hub import ProfanityFree, ValidLength

def my_custom_llm(**kwargs) -> str:
    # Your custom LLM implementation
    messages = kwargs.get("messages")
    return your_llm_call(messages)

guard = Guard().use(
    ProfanityFree(),
    ValidLength(min=10, max=100)
)

result = guard(
    my_custom_llm,
    messages=[{"role": "user", "content": "Tell me a joke"}]
)

Structured data with custom LLMs

You can also use custom LLMs with structured data generation. The wrapper must return text that Guardrails can parse into the target schema, typically a JSON string matching the Pydantic model:
from pydantic import BaseModel
from guardrails import Guard

class Person(BaseModel):
    name: str
    age: int

def my_custom_llm(**kwargs) -> str:
    # Your custom LLM implementation
    messages = kwargs.get("messages")
    return your_llm_call(messages)

guard = Guard.for_pydantic(Person)

result = guard(
    my_custom_llm,
    messages=[{"role": "user", "content": "Generate a person"}]
)

print(result.validated_output)
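
To make this concrete, here is a purely illustrative stub in place of a real LLM call: it returns a JSON string matching the Person schema, so result.validated_output comes back as a dictionary such as {"name": "Ada", "age": 36}.
import json

def my_custom_llm(**kwargs) -> str:
    # Stubbed output; a real wrapper would call your LLM here
    return json.dumps({"name": "Ada", "age": 36})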

Best practices

  1. Error handling: Implement proper error handling in your wrapper
  2. Timeouts: Set appropriate timeouts for API calls
  3. Retries: Consider implementing retry logic for transient failures (a combined sketch follows this list)
  4. Logging: Add logging to help debug issues
  5. Type hints: Use type hints for better code clarity
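
The sketch below combines several of these practices for a REST-based wrapper. It is illustrative only: the endpoint is a placeholder, and the retry count, backoff, and timeout values are arbitrary.
import logging
import time

import requests

logger = logging.getLogger(__name__)

def robust_llm_api(**kwargs) -> str:
    """Wrapper with a timeout, simple retries, and logging."""
    messages = kwargs.get("messages")
    last_error = None
    
    for attempt in range(3):  # up to 3 attempts
        try:
            response = requests.post(
                "https://your-llm-api.com/generate",  # placeholder endpoint
                json={"messages": messages},
                timeout=30,  # avoid hanging on a stalled request
            )
            response.raise_for_status()
            return response.json()["output"]
        except requests.RequestException as exc:
            last_error = exc
            logger.warning("LLM call failed (attempt %d): %s", attempt + 1, exc)
            time.sleep(2 ** attempt)  # exponential backoff
    
    raise RuntimeError("LLM call failed after retries") from last_error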

Alternative: Use LiteLLM

If you don’t want to build a custom wrapper, consider using LiteLLM, which supports 100+ LLM providers out of the box.
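
As a minimal sketch, assuming a recent Guardrails release with litellm installed and your provider API key (for example OPENAI_API_KEY) set in the environment, you pass a model string instead of a custom callable:
from guardrails import Guard
from guardrails.hub import ProfanityFree

guard = Guard().use(ProfanityFree())

# With a model string, Guardrails routes the call through LiteLLM
result = guard(
    model="gpt-4o-mini",  # any LiteLLM-supported model identifier
    messages=[{"role": "user", "content": "Tell me a story"}],
)

print(result.validated_output)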