# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The trace module holds utilities for tracing runtime requests.
"""
# Standard
from contextlib import contextmanager
from typing import TYPE_CHECKING, Iterable, Iterator, Optional, Union
import os

# Third Party
import grpc

# First Party
import alog

# Local
from ..config import get_config
from ..core.data_model.runtime_context import RuntimeServerContextType
from ..core.exceptions import error_handler
log = alog.use_channel("TRACE")
error = error_handler.get(log)

# Global handle to the trace and propagate modules that will be populated in
# configure()
_TRACE_MODULE = None
_PROPAGATE_MODULE = None

if TYPE_CHECKING:
    # Third Party
    # Context is exported from opentelemetry.context; the top-level
    # opentelemetry package does not re-export it.
    from opentelemetry.context import Context
    from opentelemetry.trace import Span, Tracer
[docs]
def get_tracer(name: str) -> Union["_NoOpProxy", "Tracer"]:
    """Return an opentelemetry tracer registered under the given name.

    When no trace module has been loaded, a _NoOpProxy is handed back instead,
    so callers can use the tracer API unconditionally.

    Args:
        name: The name to request the tracer under

    Returns:
        A real Tracer when tracing is configured, otherwise a _NoOpProxy
    """
    if not _TRACE_MODULE:
        return _NoOpProxy()
    return _TRACE_MODULE.get_tracer(name)
[docs]
def get_trace_context(runtime_context: RuntimeServerContextType) -> Optional["Context"]:
    """Extract the trace context from the runtime request context

    Args:
        runtime_context: The per-request server context. A gRPC
            ServicerContext (trace headers read from its invocation metadata)
            and a fastapi Request (trace headers read from its HTTP headers)
            are supported.

    Returns:
        The context extracted by the configured propagate module, or None if
        there is no runtime context, no propagate module is configured, or the
        context type is not recognized.
    """
    if runtime_context is None or not _PROPAGATE_MODULE:
        return None

    if isinstance(runtime_context, grpc.ServicerContext):
        # Defensively treat missing metadata as empty rather than letting
        # dict(None) raise a TypeError
        return _PROPAGATE_MODULE.extract(
            carrier=dict(runtime_context.invocation_metadata() or ())
        )

    # Local import of fastapi as an optional dependency. Only the import itself
    # is guarded so that an ImportError raised while extracting the context is
    # not silently swallowed.
    try:
        # Third Party
        import fastapi
    except ImportError:
        pass
    else:
        if isinstance(runtime_context, fastapi.Request):
            return _PROPAGATE_MODULE.extract(carrier=runtime_context.headers)

    log.debug("Unknown context type: %s", type(runtime_context))
    return None
[docs]
def set_tracer(runtime_context: RuntimeServerContextType, tracer: "Tracer"):
    """Helper to decorate a runtime context with a tracer

    The tracer is stored as an attribute on the runtime context so that
    start_child_span() can later look it up and create spans from it.

    Args:
        runtime_context: The per-request server context to decorate. When
            None, this is a no-op.
        tracer: The tracer to attach to the context
    """
    # Explicit None check instead of truthiness: a fastapi/starlette Request
    # is a Mapping, so its truthiness reflects its length rather than whether
    # a context was supplied. This also matches get_trace_context's check.
    if runtime_context is not None:
        setattr(runtime_context, _CONTEXT_TRACER_ATTR, tracer)
[docs]
@contextmanager
def start_child_span(
    runtime_context: RuntimeServerContextType,
    span_name: str,
) -> Iterator[Union["Span", "_NoOpProxy"]]:
    """Context manager that wraps start_as_current_span for the given name

    The tracer attached to the runtime context by set_tracer() is used when
    present. Otherwise this falls back to get_tracer(span_name), which yields
    a no-op proxy when tracing is not configured.

    Args:
        runtime_context: The per-request server context (may be None)
        span_name: Name of the span to start

    Yields:
        The started span, or a _NoOpProxy when tracing is not configured
    """
    # getattr with a default also covers runtime_context being None
    parent_tracer = getattr(runtime_context, _CONTEXT_TRACER_ATTR, None)
    if parent_tracer is None:
        parent_tracer = get_tracer(span_name)
    with parent_tracer.start_as_current_span(span_name) as span:
        yield span
## Implementation Details ######################################################
# Name of the attribute that set_tracer() attaches to a runtime context and
# that start_child_span() reads back via getattr. Those functions only look it
# up at call time, so defining it after them is fine.
_CONTEXT_TRACER_ATTR = "__tracer__"
[docs]
def _load_tls_secret(tls_config_val: str) -> bytes:
    """Resolve a TLS config value into raw secret bytes.

    If the value names something that exists on disk, that file's contents
    are returned verbatim. Otherwise the value is taken to be the secret
    itself, given inline, and is returned UTF-8 encoded.

    Args:
        tls_config_val: A filesystem path or an inline secret string

    Returns:
        The secret as bytes
    """
    if not os.path.exists(tls_config_val):
        # Not on disk: the config value itself is the secret material
        return tls_config_val.encode("utf-8")
    with open(tls_config_val, "rb") as secret_file:
        contents = secret_file.read()
    return contents
[docs]
class _NoOpProxy:
    """Do-nothing stand-in for opentelemetry objects.

    Any missing-attribute lookup, any call, and entering a ``with`` block all
    yield this very same instance. Arbitrarily chained usage such as
    ``proxy.start_as_current_span(name).set_attribute(k, v)`` is therefore a
    silent no-op when tracing is either not configured or not available.
    Leaving a ``with`` block returns None, so exceptions raised inside the
    block still propagate.
    """

    def _return_self(self, *_args, **_kwargs) -> "_NoOpProxy":
        """Ignore every argument and hand back this same proxy"""
        return self

    # Attribute lookups, calls, and context entry all resolve to the proxy
    __getattr__ = _return_self
    __call__ = _return_self
    __enter__ = _return_self

    # Keep the helper out of the class namespace so that "_return_self" is
    # resolved through __getattr__ like any other attribute name
    del _return_self

    def __exit__(self, *_args, **_kwargs) -> None:
        """Leave a ``with`` block without suppressing any exception"""
        return None