text2vec-cohere
In short
- This module uses a third-party API and may incur costs.
- Make sure to check the Cohere pricing page before vectorizing large amounts of data.
- Weaviate automatically parallelizes requests to the Cohere-API when using the batch endpoint.
Introduction
The `text2vec-cohere` module enables you to use Cohere embeddings in Weaviate to represent data objects and run semantic (`nearText`) queries.
How to enable
Request a Cohere API-key via their dashboard.
Weaviate Cloud Services
This module is enabled by default on WCS.
Weaviate open source
You can find an example Docker-compose file below, which will spin up Weaviate with the Cohere module.
---
# Docker Compose file that starts a single Weaviate node with the
# text2vec-cohere module enabled and set as the default vectorizer.
version: '3.4'
services:
  weaviate:
    image: semitechnologies/weaviate:1.19.6
    restart: on-failure:0
    ports:
      - "8080:8080"
    environment:
      QUERY_DEFAULTS_LIMIT: 20
      AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true'
      PERSISTENCE_DATA_PATH: "./data"
      DEFAULT_VECTORIZER_MODULE: text2vec-cohere
      ENABLE_MODULES: text2vec-cohere
      COHERE_APIKEY: sk-foobar # request a key on cohere.com, setting this parameter is optional, you can also provide the API key at runtime
      CLUSTER_HOSTNAME: 'node1'
...
- You can also use the Weaviate configuration tool to create a Weaviate setup with this module.
- The `COHERE_APIKEY` environment variable is optional; you can instead provide the key at insert/query time as an HTTP header (see the 'Usage' section for instructions).
How to configure
In your Weaviate schema, you must define how you want this module to vectorize your data. If you are new to Weaviate schemas, you might want to check out the tutorial on the Weaviate schema first.
The following schema configuration tells Weaviate to vectorize the `Document` class with `text2vec-cohere`, using the `multilingual-22-12` model and letting the Cohere API truncate over-long input from the right (`"truncate": "RIGHT"`).
The multilingual models use dot product, and the English model uses cosine. Make sure to set this accordingly in your Weaviate schema. You can see supported distance metrics here.
// Schema defining a Document class that is vectorized by the text2vec-cohere module.
{
"classes": [
{
"class": "Document",
"description": "A class called document",
"vectorizer": "text2vec-cohere",
"vectorIndexConfig": {
"distance": "dot" // <== Cohere's multilingual models use dot product instead of the Weaviate default cosine (the English model uses cosine)
},
"moduleConfig": {
"text2vec-cohere": {
"model": "multilingual-22-12", // <== defaults to multilingual-22-12 if not set
"truncate": "RIGHT" // <== defaults to RIGHT if not set
}
},
"properties": [
{
"dataType": [
"text"
],
"description": "Content that will be vectorized",
"moduleConfig": {
"text2vec-cohere": {
"skip": false,
"vectorizePropertyName": false
}
},
"name": "content"
}
]
}
]
}
Usage
- If the Cohere API key is not set in the `text2vec-cohere` module, you can set the API key at query time by adding the following to the HTTP header: `X-Cohere-Api-Key: YOUR-COHERE-API-KEY`.
- Using this module will enable GraphQL vector search operators.
Example
- GraphQL
- Python
- JavaScript
- Go
- Java
- Curl
# nearText search over the Publication class: results close to "fashion",
# moved toward "haute couture" and away from "finance".
{
Get{
Publication(
nearText: {
concepts: ["fashion"],
distance: 0.6 # prior to v1.14 use "certainty" instead of "distance"
moveAwayFrom: {
concepts: ["finance"],
force: 0.45
},
moveTo: {
concepts: ["haute couture"],
force: 0.85
}
}
){
name
_additional {
certainty # only supported if distance==cosine.
distance # always supported
}
}
}
}
import weaviate

# Client that sends the Cohere API key as an HTTP header with every request.
client = weaviate.Client(
    url="http://localhost:8080",
    additional_headers={
        "X-Cohere-Api-Key": "YOUR-COHERE-API-KEY"
    }
)

# nearText arguments: search near "fashion", moved toward "haute couture"
# and away from "finance".
nearText = {
    "concepts": ["fashion"],
    "distance": 0.6,  # prior to v1.14 use "certainty" instead of "distance"
    "moveAwayFrom": {
        "concepts": ["finance"],
        "force": 0.45
    },
    "moveTo": {
        "concepts": ["haute couture"],
        "force": 0.85
    }
}

result = (
    client.query
    .get("Publication", "name")
    # Request each _additional property as its own list entry; a single
    # "certainty OR distance" string is not a valid property name.
    # Note that certainty is only supported if distance==cosine.
    .with_additional(["certainty", "distance"])
    .with_near_text(nearText)
    .do()
)

print(result)
const weaviate = require('weaviate-client');

// Client that forwards the Cohere API key as an HTTP header on each request.
const client = weaviate.client({
  scheme: 'http',
  host: 'localhost:8080',
  headers: {'X-Cohere-Api-Key': 'YOUR-COHERE-API-KEY'},
});

// Search near "fashion", pulled toward "haute couture" and pushed away from "finance".
const nearTextArgs = {
  concepts: ['fashion'],
  distance: 0.6, // prior to v1.14 use certainty instead of distance
  moveAwayFrom: {concepts: ['finance'], force: 0.45},
  moveTo: {concepts: ['haute couture'], force: 0.85},
};

// note that certainty is only supported if distance==cosine
const fields = 'name _additional{certainty distance}';

client.graphql
  .get()
  .withClassName('Publication')
  .withFields(fields)
  .withNearText(nearTextArgs)
  .do()
  .then((res) => console.log(res))
  .catch((err) => console.error(err));
package main
import (
"context"
"fmt"
"github.com/weaviate/weaviate-go-client/v4/weaviate"
"github.com/weaviate/weaviate-go-client/v4/weaviate/graphql"
)
// main runs a nearText query against the Publication class, sending the
// Cohere API key as a request header, and prints the raw GraphQL response.
func main() {
	client, err := weaviate.NewClient(weaviate.Config{
		Scheme:  "http",
		Host:    "localhost:8080",
		Headers: map[string]string{"X-Cohere-Api-Key": "YOUR-COHERE-API-KEY"},
	})
	if err != nil {
		panic(err)
	}

	// Fields to return for every hit.
	fields := []graphql.Field{
		{Name: "name"},
		{
			Name: "_additional",
			Fields: []graphql.Field{
				{Name: "certainty"}, // only supported if distance==cosine
				{Name: "distance"},  // always supported
			},
		},
	}

	const maxDistance float32 = 0.6

	// Search near "fashion", moved toward "haute couture" and away from "finance".
	towards := &graphql.MoveParameters{
		Concepts: []string{"haute couture"},
		Force:    0.85,
	}
	away := &graphql.MoveParameters{
		Concepts: []string{"finance"},
		Force:    0.45,
	}
	nearText := client.GraphQL().NearTextArgBuilder().
		WithConcepts([]string{"fashion"}).
		WithDistance(maxDistance). // use WithCertainty(certainty) prior to v1.14
		WithMoveTo(towards).
		WithMoveAwayFrom(away)

	result, err := client.GraphQL().Get().
		WithClassName("Publication").
		WithFields(fields...).
		WithNearText(nearText).
		Do(context.Background())
	if err != nil {
		panic(err)
	}
	fmt.Printf("%v", result)
}
package io.weaviate;
import io.weaviate.client.Config;
import io.weaviate.client.WeaviateClient;
import io.weaviate.client.base.Result;
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
import io.weaviate.client.v1.graphql.query.argument.NearTextArgument;
import io.weaviate.client.v1.graphql.query.argument.NearTextMoveParameters;
import io.weaviate.client.v1.graphql.query.fields.Field;
import java.util.HashMap;
import java.util.Map;
/**
 * Runs a nearText query against the Publication class, sending the Cohere API
 * key as a request header, and prints either the error or the result.
 */
public class App {
  public static void main(String[] args) {
    // Header carrying the Cohere API key for every request.
    Map<String, String> apiHeaders = new HashMap<>();
    apiHeaders.put("X-Cohere-Api-Key", "YOUR-COHERE-API-KEY");

    WeaviateClient client = new WeaviateClient(new Config("http", "localhost:8080", apiHeaders));

    // Fields returned for each hit.
    Field nameField = Field.builder().name("name").build();
    Field additionalField = Field.builder()
      .name("_additional")
      .fields(new Field[]{
        Field.builder().name("certainty").build(), // only supported if distance==cosine
        Field.builder().name("distance").build(),  // always supported
      })
      .build();

    // Search near "fashion", moved toward "haute couture" and away from "finance".
    NearTextMoveParameters towards = NearTextMoveParameters.builder()
      .concepts(new String[]{ "haute couture" })
      .force(0.85f)
      .build();
    NearTextMoveParameters awayFrom = NearTextMoveParameters.builder()
      .concepts(new String[]{ "finance" })
      .force(0.45f)
      .build();
    NearTextArgument nearText = client.graphQL().arguments().nearTextArgBuilder()
      .concepts(new String[]{ "fashion" })
      .distance(0.6f) // use .certainty(0.7f) prior to v1.14
      .moveTo(towards)
      .moveAwayFrom(awayFrom)
      .build();

    Result<GraphQLResponse> response = client.graphQL().get()
      .withClassName("Publication")
      .withFields(nameField, additionalField)
      .withNearText(nearText)
      .run();

    if (response.hasErrors()) {
      System.out.println(response.getError());
    } else {
      System.out.println(response.getResult());
    }
  }
}
# Notes on the query below:
#  - prior to v1.14 use "certainty" instead of "distance" in nearText
#  - in _additional, "certainty" is only supported if distance==cosine;
#    "distance" is always supported
# The notes are kept out of the payload on purpose: "//" is not GraphQL comment
# syntax, and curl -d @- strips newlines, so any inline comment would swallow
# the rest of the query.
$ echo '{
  "query": "{
    Get{
      Publication(
        nearText: {
          concepts: [\"fashion\"],
          distance: 0.6,
          moveAwayFrom: {
            concepts: [\"finance\"],
            force: 0.45
          },
          moveTo: {
            concepts: [\"haute couture\"],
            force: 0.85
          }
        }
      ){
        name
        _additional {
          certainty
          distance
        }
      }
    }
  }"
}' | curl \
    -X POST \
    -H 'Content-Type: application/json' \
    -H "X-Cohere-Api-Key: YOUR-COHERE-API-KEY" \
    -d @- \
    http://localhost:8080/v1/graphql
Additional information
Available models
Weaviate defaults to Cohere's `multilingual-22-12` embedding model unless specified otherwise.
For example, the following schema configuration will set Weaviate to vectorize the `Document` class with `text2vec-cohere` using the `multilingual-22-12` model.
{
  "classes": [
    {
      "class": "Document",
      "description": "A class called document",
      "vectorizer": "text2vec-cohere",
      "vectorIndexConfig": {
        "distance": "dot"
      },
      "moduleConfig": {
        "text2vec-cohere": {
          "model": "multilingual-22-12"
        }
      }
    }
  ]
}
Truncation
If the input text contains too many tokens and is not truncated, the API will throw an error. The Cohere API can be set to automatically truncate your input text.
You can set the truncation option with the `truncate` parameter to `RIGHT` or `NONE`. Passing `RIGHT` will discard the right side of the input, so that the remaining input is exactly the maximum input token length for the model (source).
- The upside of truncating is that a batch import always succeeds.
- The downside of truncating (i.e., `RIGHT`) is that a large text will be partially vectorized without the user being made aware of the truncation.
Cohere Rate Limits
Because you will be getting embeddings based on your own API key, you will be dealing with rate limits applied to your account. More information about Cohere rate limits can be found here.
Throttle the import inside your application
If you run into rate limits, you can also decide to throttle the import in your application.
E.g., in Python and Go using the Weaviate client.
- Python
- Go
from weaviate import Client
import time
def configure_batch(client: Client, batch_size: int, batch_target_rate: int):
    """
    Configure the weaviate client's batch so it creates objects at `batch_target_rate`.

    The throttling is done by a callback registered on the batch: after each
    batch it sleeps long enough to keep the import at the target rate.

    Parameters
    ----------
    client : Client
        The Weaviate client instance.
    batch_size : int
        The batch size.
    batch_target_rate : int
        The batch target rate as # of objects per second.
    """

    def callback(batch_results: dict) -> None:
        # you could print batch errors here

        # Scale the client's reported creation time per recommended object
        # count up to a batch of `batch_size` objects.
        time_took_to_create_batch = batch_size * (
            client.batch.creation_time / client.batch.recommended_num_objects
        )
        # Sleep for the rest of the batch's time budget, plus one second;
        # never pass a negative duration to sleep.
        time.sleep(
            max(batch_size / batch_target_rate - time_took_to_create_batch + 1, 0)
        )

    client.batch.configure(
        batch_size=batch_size,
        timeout_retries=5,
        callback=callback,
    )
package main
import (
"context"
"time"
"github.com/weaviate/weaviate-go-client/v4/weaviate"
"github.com/weaviate/weaviate/entities/models"
)
var (
	// adjust to your liking
	targetRatePerMin = 600 // objects per minute the import should stay at
	batchSize        = 50  // objects sent per batch request
)

// main imports objects in batches of batchSize and paces the batches with a
// ticker so that roughly targetRatePerMin objects are sent per minute.
func main() {
	// A non-positive batchSize would never advance the loop below, and a
	// non-positive rate would divide by zero.
	if batchSize <= 0 || targetRatePerMin <= 0 {
		panic("batchSize and targetRatePerMin must be positive")
	}

	cfg := weaviate.Config{
		Host:   "localhost:8080",
		Scheme: "http",
	}
	client, err := weaviate.NewClient(cfg)
	if err != nil {
		panic(err)
	}

	// replace those 10000 empty objects with your actual data
	objects := make([]*models.Object, 10000)

	// We aim to send one batch every tickInterval. Multiply before dividing:
	// the integer division batchSize/targetRatePerMin truncates to 0 whenever
	// batchSize < targetRatePerMin, and time.NewTicker panics on a
	// non-positive interval.
	tickInterval := time.Duration(batchSize) * time.Minute / time.Duration(targetRatePerMin)
	t := time.NewTicker(tickInterval)
	defer t.Stop()

	for i := 0; i < len(objects); i += batchSize {
		// clamp the final batch so it never reads past the end of objects
		end := i + batchSize
		if end > len(objects) {
			end = len(objects)
		}

		// create a fresh batch
		batch := client.Batch().ObjectsBatcher()

		// add the objects of this window to the batch
		for _, obj := range objects[i:end] {
			batch = batch.WithObject(obj)
		}

		// send off batch
		res, err := batch.Do(context.Background())
		// TODO: inspect result for individual errors
		_ = res
		// TODO: check request error
		_ = err

		// we wait for the next tick. If the previous batch took longer than
		// tickInterval, we won't need to wait, effectively making this an
		// unthrottled import.
		<-t.C
	}
}
More resources
If you can't find the answer to your question here, please look at the:
- Frequently Asked Questions. Or,
- Knowledge base of old issues. Or,
- For questions: Stackoverflow. Or,
- For more involved discussion: Weaviate Community Forum. Or,
- We also have a Slack channel.