# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Basic routines for reading and writing common file types.
"""
# Standard
import csv
import json
import os
import pickle
import shutil
# Third Party
import yaml
[docs]
def load_txt(filename):
    """Read the full contents of a utf8-encoded text file into one string."""
    with open(filename, mode="r", encoding="utf8") as text_file:
        contents = text_file.read()
    return contents
[docs]
def load_txt_lines(filename):
    """Read a utf8 text file and return its lines as a list of strings.

    Leading and trailing whitespace (including the line terminator) is
    stripped from every line; blank lines are kept as empty strings.
    """
    with open(filename, encoding="utf8") as text_file:
        return [line.strip() for line in text_file]
[docs]
def save_txt(text, filename, mode="w"):
    """Write the string `text` to `filename` using utf8 encoding.

    `mode` is passed straight to `open`, so "a" appends instead of overwriting.
    """
    out_file = open(filename, mode=mode, encoding="utf8")
    with out_file:
        out_file.write(text)
[docs]
def load_binary(filename):
    """Load the raw contents of a file as bytes.

    Args:
        filename (str): Path of the file to read.

    Returns:
        bytes: The complete, undecoded file contents.
    """
    # No `encoding` here: open() raises ValueError when an encoding is
    # supplied together with a binary mode.
    with open(filename, mode="rb") as fh:
        return fh.read()
[docs]
def save_binary(data, filename):
    """Write a binary buffer to a file, replacing any existing contents.

    Args:
        data (bytes): Bytes-like object to write.
        filename (str): Path of the file to create or overwrite.
    """
    # No `encoding` here: open() raises ValueError when an encoding is
    # supplied together with a binary mode.
    with open(filename, mode="wb") as fh:
        fh.write(data)
[docs]
def load_csv(filename):
    """Parse a comma-separated utf-8 file and return its rows as a list of lists."""
    with open(filename, encoding="utf-8", newline="") as csv_file:
        row_reader = csv.reader(csv_file, delimiter=",", quotechar='"')
        rows = list(row_reader)
    return rows
[docs]
def save_csv(text_list, filename, mode="w"):
    """Write rows (a list of lists) to `filename` as comma-separated utf-8 text."""
    with open(filename, mode=mode, encoding="utf-8", newline="") as csv_file:
        csv.writer(csv_file, delimiter=",").writerows(text_list)
[docs]
def load_dict_csv(filename):
    """Load a csv into a list-of-dicts.

    The first row of the file supplies the dictionary keys.

    Args:
        filename (str): Path of the utf-8 encoded csv file to read.

    Returns:
        list: One dict per data row, mapping column name to cell text.
    """
    # newline="" lets the csv module do its own line-ending handling, so
    # newlines embedded in quoted fields are preserved as written.
    with open(filename, newline="", encoding="utf-8") as csv_file:
        csv_reader = csv.DictReader(csv_file)
        return list(csv_reader)
[docs]
def save_dict_csv(dict_list, filename, mode="w"):
    """Write a list of dicts to a csv file.

    The keys of the first dict define the columns and are written as the
    header row. If `dict_list` is empty nothing is written and the file is
    not touched.

    Args:
        dict_list (list): Dicts to write, one per row.
        filename (str): Path of the utf-8 encoded csv file to write.
        mode (str): Mode passed to `open`. Defaults to "w".
    """
    if dict_list:
        keys = dict_list[0].keys()
        # newline="" stops the text layer from translating the "\r\n" that
        # the csv writer emits, which would otherwise become "\r\r\n" on
        # platforms whose line separator is "\r\n".
        with open(filename, mode=mode, newline="", encoding="utf-8") as output_file:
            dict_writer = csv.DictWriter(output_file, keys)
            dict_writer.writeheader()
            dict_writer.writerows(dict_list)
[docs]
def load_json(filename):
    """Parse a utf8-encoded json file and return the decoded object."""
    with open(filename, encoding="utf8") as json_file:
        raw_text = json_file.read()
    return json.loads(raw_text)
[docs]
def save_json(save_dict, filename, mode="w"):
    """Serialize `save_dict` to `filename` as utf8 json.

    Output is indented by two spaces and non-ASCII characters are written
    as-is rather than escaped.
    """
    dump_options = {"indent": 2, "ensure_ascii": False}
    with open(filename, mode=mode, encoding="utf8") as json_file:
        json.dump(save_dict, json_file, **dump_options)
[docs]
def load_yaml(filename):
    """Parse a utf8-encoded yaml file with the safe loader and return the result."""
    with open(filename, mode="r", encoding="utf8") as yaml_file:
        parsed = yaml.safe_load(yaml_file)
    return parsed
[docs]
def save_yaml(save_dict, filename, mode="w"):
    """Serialize `save_dict` to `filename` as block-style yaml using the safe dumper."""
    with open(filename, mode=mode, encoding="utf8") as yaml_file:
        yaml.safe_dump(save_dict, stream=yaml_file, default_flow_style=False)
[docs]
def load_pickle(filename):
    """Read one pickled object from `filename` and return it.

    NOTE: unpickling can execute arbitrary code, so this should only be used
    on files from a trusted source.
    """
    with open(filename, mode="rb") as pickle_file:
        unpickler = pickle.Unpickler(pickle_file)
        return unpickler.load()
[docs]
def save_pickle(obj, filename, mode="wb"):
    """Pickle `obj` into `filename`, opened with the given (binary) mode."""
    # pylint: disable=unspecified-encoding
    with open(filename, mode=mode) as pickle_file:
        pickle.Pickler(pickle_file).dump(obj)
[docs]
def save_raw(save_content, filename, mode="w"):
    """Write the raw string `save_content` to `filename` with utf8 encoding."""
    out_file = open(filename, mode=mode, encoding="utf8")
    with out_file:
        out_file.write(save_content)
[docs]
def compress(dir_path, output_path=None, extension="zip"):
    """Compress a given folder recursively to an archive with a given extension format

    Args:
        dir_path (str): Path of directory to compress
        output_path: (Optional) str
            Path of the archive to create, without the format suffix.
            Defaults to 'archive' inside the current working directory.
                >>> compress('.', 'my/path', 'tar')
                >>> # saves to 'my/path.tar'
        extension: (Optional) (one of: zip/tar/gztar/bztar/xztar depending on module availability)
            Archive format name handed to shutil.make_archive. Anything up to
            and including the last '.' is dropped, so '.zip' works too.
            Defaults to zip

    Returns:
        str: Path to created archive
    """
    if not output_path:
        output_path = os.path.join(os.getcwd(), "archive")

    # Strip away anything preceding '.'
    extension = extension.split(".")[-1]
    shutil.make_archive(output_path, extension, dir_path)

    # The compressed tar formats are named differently from the suffix that
    # make_archive appends, so translate them; any other format name is used
    # as the suffix unchanged.
    compressed_tar_suffixes = {
        "gztar": "tar.gz",
        "bztar": "tar.bz2",
        "xztar": "tar.xz",
    }
    return output_path + "." + compressed_tar_suffixes.get(extension, extension)