Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add first two copilot integration tests #523

Merged
merged 4 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions data/malicious.jsonl
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
{"name":"malicious-go-dummy","type":"go","description":"Dummy malicious to test with simple package name on go"}
{"name":"@prefix/malicious-crates-dummy","type":"crates","description":"Dummy malicious to test with encoded package name on crates"}
{"name":"malicious-crates-dummy","type":"crates","description":"Dummy malicious to test with simple package name on crates"}
{"name":"invokehttp","type":"pypi","description":"Invokehttp is a malicious package"}
27 changes: 26 additions & 1 deletion src/codegate/providers/copilot/mapping.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import List
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel, HttpUrl
from pydantic_settings import BaseSettings
Expand Down Expand Up @@ -43,3 +45,26 @@ class CoPilotMappings(BaseSettings):
VALIDATED_ROUTES: List[CopilotProxyRoute] = [
CopilotProxyRoute(path=path, target=target) for path, target in mappings.PROXY_ROUTES
]


class PipelineType(Enum):
    """Kind of Copilot request a proxied path is dispatched to."""

    FIM = "fim"  # fill-in-the-middle code-completion requests
    CHAT = "chat"  # conversational chat-completion requests


@dataclass
class PipelineRoute:
    """Maps a proxied request path to the pipeline type that should handle it."""

    path: str  # request path without a leading slash (see PIPELINE_ROUTES below)
    pipeline_type: PipelineType
    target_url: Optional[str] = None  # optional upstream override; unused when None


# Routes scanned in order by the provider to select a pipeline for an
# incoming POST request; the first path match wins.
PIPELINE_ROUTES: List[PipelineRoute] = [
    PipelineRoute(
        path="v1/chat/completions",
        # target_url="https://api.openai.com/v1/chat/completions",
        pipeline_type=PipelineType.CHAT,
    ),
    PipelineRoute(path="v1/engines/copilot-codex/completions", pipeline_type=PipelineType.FIM),
    PipelineRoute(path="chat/completions", pipeline_type=PipelineType.CHAT),
]
105 changes: 96 additions & 9 deletions src/codegate/providers/copilot/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from codegate.pipeline.factory import PipelineFactory
from codegate.pipeline.output import OutputPipelineInstance
from codegate.pipeline.secrets.manager import SecretsManager
from codegate.providers.copilot.mapping import VALIDATED_ROUTES
from codegate.providers.copilot.mapping import PIPELINE_ROUTES, VALIDATED_ROUTES, PipelineType
from codegate.providers.copilot.pipeline import (
CopilotChatPipeline,
CopilotFimPipeline,
Expand Down Expand Up @@ -153,12 +153,18 @@ def __init__(self, loop: asyncio.AbstractEventLoop):
self.context_tracking: Optional[PipelineContext] = None

def _select_pipeline(self, method: str, path: str) -> Optional[CopilotPipeline]:
    """
    Pick the Copilot pipeline for an incoming request.

    Only POST requests are routed; the path is matched against the
    PIPELINE_ROUTES table (first match wins). Returns None when no
    route matches, in which case the request is forwarded untouched.

    NOTE: the merge left the old hard-coded path checks in place above the
    route-table loop, which made PIPELINE_ROUTES dead code for those two
    paths; this version routes exclusively through the table.
    """
    if method != "POST":
        logger.debug("Not a POST request, no pipeline selected")
        return None

    for route in PIPELINE_ROUTES:
        if path == route.path:
            if route.pipeline_type == PipelineType.FIM:
                logger.debug("Selected FIM pipeline")
                return CopilotFimPipeline(self.pipeline_factory)
            elif route.pipeline_type == PipelineType.CHAT:
                logger.debug("Selected CHAT pipeline")
                return CopilotChatPipeline(self.pipeline_factory)

    logger.debug("No pipeline selected")
    return None
Expand Down Expand Up @@ -350,8 +356,82 @@ async def _forward_data_to_target(self, data: bytes) -> None:
pipeline_output = pipeline_output.reconstruct()
self.target_transport.write(pipeline_output)

def _has_complete_body(self) -> bool:
    """
    Check if we have received the complete request body based on Content-Length.

    Headers are parsed from the buffer instead of self.request.headers on
    purpose: with CONNECT tunnels the whole follow-up request arrives as raw
    data and is only available in the buffer.

    Returns True when the buffered request can be processed, False when more
    data is still expected (or the request is malformed).
    """
    try:
        # The initial CONNECT request has no body to wait for.
        if not self.headers_parsed and self.request and self.request.method == "CONNECT":
            return True

        # For subsequent requests or non-CONNECT requests, parse the method
        # from the request line in the buffer.
        try:
            first_line = self.buffer[: self.buffer.index(b"\r\n")].decode("utf-8")
            method = first_line.split()[0]
        except (ValueError, IndexError):
            # Haven't received the complete request line yet.
            return False

        if method != "POST":  # do we need to check for other methods? PUT?
            return True

        # Parse headers from the buffer instead of using self.request.headers.
        headers_dict = {}
        try:
            headers_end = self.buffer.index(b"\r\n\r\n")
            if headers_end <= 0:  # Ensure we have a valid headers section
                return False

            headers = self.buffer[:headers_end].split(b"\r\n")
            if len(headers) <= 1:  # Ensure we have headers after the request line
                return False

            for header in headers[1:]:  # Skip the request line
                if not header:  # Skip empty lines
                    continue
                try:
                    name, value = header.decode("utf-8").split(":", 1)
                    headers_dict[name.strip().lower()] = value.strip()
                except ValueError:
                    # Skip malformed headers.
                    continue
        except ValueError:
            # Haven't received the complete headers yet.
            return False

        # TODO: Add proper support for chunked transfer encoding
        # For now, just pass through and let the pipeline handle it
        if "transfer-encoding" in headers_dict:
            return True

        try:
            content_length = int(headers_dict.get("content-length"))
        except (ValueError, TypeError):
            # Content-Length header is required for POST requests without
            # chunked encoding.
            logger.error("Missing or invalid Content-Length header in POST request")
            return False

        body_start = headers_end + 4  # skip past the b"\r\n\r\n" separator
        # body_start <= len(self.buffer) always holds here because the
        # separator was found above. Comparing lengths directly (instead of
        # the previous early `body_start >= len(buffer)` return) fixes the
        # Content-Length: 0 case, which was wrongly reported as incomplete.
        current_body_length = len(self.buffer) - body_start
        return current_body_length >= content_length
    except Exception as e:
        logger.error(f"Error checking body completion: {e}")
        return False

def data_received(self, data: bytes) -> None:
"""Handle received data from client"""
"""
Handle received data from client. Since we need to process the complete body
through our pipeline before forwarding, we accumulate the entire request first.
"""
logger.info(f"Received data from {self.peername}: {data}")
try:
if not self._check_buffer_size(data):
self.send_error_response(413, b"Request body too large")
Expand All @@ -364,10 +444,17 @@ def data_received(self, data: bytes) -> None:
if self.headers_parsed:
if self.request.method == "CONNECT":
self.handle_connect()
self.buffer.clear()
else:
# Only process the request once we have the complete body
asyncio.create_task(self.handle_http_request())
else:
asyncio.create_task(self._forward_data_to_target(data))
if self._has_complete_body():
# Process the complete request through the pipeline
complete_request = bytes(self.buffer)
logger.debug(f"Complete request: {complete_request}")
self.buffer.clear()
asyncio.create_task(self._forward_data_to_target(complete_request))

except Exception as e:
logger.error(f"Error processing received data: {e}")
Expand Down
86 changes: 86 additions & 0 deletions tests/integration/checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from abc import ABC, abstractmethod
from typing import List

import structlog
from sklearn.metrics.pairwise import cosine_similarity

from codegate.inference.inference_engine import LlamaCppInferenceEngine

logger = structlog.get_logger("codegate")


class BaseCheck(ABC):
    """Abstract base for a single named integration-test check."""

    def __init__(self, test_name: str):
        # Kept so subclasses can name the owning test in their log output.
        self.test_name = test_name

    @abstractmethod
    async def run_check(self, parsed_response: str, test_data: dict) -> bool:
        """Return True when *parsed_response* satisfies this check."""
        ...


class CheckLoader:
    """Builds the list of checks requested by a test-definition dict."""

    @staticmethod
    def load(test_data: dict) -> List[BaseCheck]:
        """Instantiate one check for each recognised key present in *test_data*."""
        test_name = test_data.get("name")
        return [
            check_cls(test_name)
            for check_cls in (DistanceCheck, ContainsCheck, DoesNotContainCheck)
            if test_data.get(check_cls.KEY)
        ]


class DistanceCheck(BaseCheck):
    """Passes when the response embeds close enough to the expected text."""

    KEY = "likes"

    def __init__(self, test_name: str):
        super().__init__(test_name)
        # Local llama.cpp engine, used here purely to produce embeddings.
        self.inference_engine = LlamaCppInferenceEngine()
        self.embedding_model = "codegate_volume/models/all-minilm-L6-v2-q5_k_m.gguf"

    async def _calculate_string_similarity(self, str1, str2):
        """Embed both strings and return their cosine-similarity row."""
        left = await self.inference_engine.embed(self.embedding_model, [str1])
        right = await self.inference_engine.embed(self.embedding_model, [str2])
        return cosine_similarity(left, right)[0]

    async def run_check(self, parsed_response: str, test_data: dict) -> bool:
        expected = test_data[DistanceCheck.KEY]
        similarity = await self._calculate_string_similarity(parsed_response, expected)
        if similarity >= 0.8:
            return True
        logger.error(f"Test {self.test_name} failed")
        logger.error(f"Similarity: {similarity}")
        logger.error(f"Response: {parsed_response}")
        logger.error(f"Expected Response: {expected}")
        return False


class ContainsCheck(BaseCheck):
    """Passes when the expected snippet occurs somewhere in the response."""

    KEY = "contains"

    async def run_check(self, parsed_response: str, test_data: dict) -> bool:
        if test_data[ContainsCheck.KEY].strip() in parsed_response:
            return True
        logger.error(f"Test {self.test_name} failed")
        logger.error(f"Response: {parsed_response}")
        logger.error(f"Expected Response to contain: '{test_data[ContainsCheck.KEY]}'")
        return False


class DoesNotContainCheck(BaseCheck):
    """Passes when the forbidden snippet is absent from the response."""

    KEY = "does_not_contain"

    async def run_check(self, parsed_response: str, test_data: dict) -> bool:
        if test_data[DoesNotContainCheck.KEY].strip() not in parsed_response:
            return True
        logger.error(f"Test {self.test_name} failed")
        logger.error(f"Response: {parsed_response}")
        logger.error(
            f"Expected Response to not contain: '{test_data[DoesNotContainCheck.KEY]}'"
        )
        return False
Loading
Loading