BenJaminBMorin
9/11/2019 - 4:13 PM

NuGet API wrapper to search/install packages WITH AUTHENTICATION the same way NuGet.exe/NuGet.Client does it - looks up credential providers and plugins

NuGet API wrapper to search/install packages WITH AUTHENTICATION the same way NuGet.exe/NuGet.Client does it: it looks up credential providers and plugins to obtain credentials. The specific use case was accessing an Azure DevOps feed.

using NuGet.Common;
using NuGet.Configuration;
using NuGet.Credentials;
using NuGet.PackageManagement;
using NuGet.Packaging;
using NuGet.Packaging.Core;
using NuGet.Packaging.Signing;
using NuGet.ProjectManagement;
using NuGet.Protocol;
using NuGet.Protocol.Core.Types;
using NuGet.Protocol.Plugins;
using NuGet.Resolver;
using NuGet.Versioning;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using System.Xml.Linq;

// Adapted from NuGet.Client (https://github.com/NuGet/NuGet.Client), in particular
// src/NuGet.Clients/NuGet.CommandLine/Commands/Command.cs (credential provider setup).
namespace NuGetAuthGist
{
    public class NuGetManager
    {
        /// <summary>Repository for the remote feed used by search and install operations.</summary>
        public SourceRepository RemoteNuGetFeed { get; private set; }

        /// <summary>Package manager used to install packages into the local packages folder.</summary>
        public NuGetPackageManager LocalNuGetPackageManager { get; private set; }

        /// <summary>NuGet configuration (machine-wide and user settings) used by this manager.</summary>
        public NugetSettings Settings { get; private set; }

        // NuGet v3 core resource providers; shared by every SourceRepository this instance creates.
        // Assigned only here, hence readonly.
        private readonly List<Lazy<INuGetResourceProvider>> providers = new List<Lazy<INuGetResourceProvider>>(Repository.Provider.GetCoreV3());

        /// <summary>
        /// Creates a manager bound to the first package source in the user NuGet settings whose URL
        /// contains "SomethingSpecial".
        /// </summary>
        /// <exception cref="InvalidOperationException">No matching package source is configured.</exception>
        public NuGetManager()
            : this("SomethingSpecial")
        {
        }

        /// <summary>
        /// Creates a manager bound to the first package source in the user NuGet settings whose URL
        /// contains <paramref name="sourceUrlFilter"/> (ordinal, case-sensitive substring match).
        /// </summary>
        /// <remarks>
        /// Process-wide side effects: sets <see cref="PreviewFeatureSettings.DefaultCredentialsAfterCredentialProviders"/>
        /// to <c>true</c> and replaces the static <see cref="HttpHandlerResourceV3.CredentialService"/> with a
        /// non-interactive service built from <see cref="GetCredentialProvidersAsync"/>.
        /// </remarks>
        /// <param name="sourceUrlFilter">Substring to look for in the configured package source URLs.</param>
        /// <exception cref="ArgumentException"><paramref name="sourceUrlFilter"/> is null or empty.</exception>
        /// <exception cref="InvalidOperationException">No matching package source is configured.</exception>
        public NuGetManager(string sourceUrlFilter)
        {
            if (string.IsNullOrEmpty(sourceUrlFilter))
                throw new ArgumentException("A non-empty source URL filter is required.", nameof(sourceUrlFilter));

            Settings = new NugetSettings();

            // GetSection may return null when there is no <packageSources> section; the null-conditional
            // short-circuits the whole chain, leaving source == null.
            var source = Settings.UserSettings.GetSection("packageSources")?.Items
                .OfType<SourceItem>()
                .FirstOrDefault(item => item.Value != null && item.Value.Contains(sourceUrlFilter));
            if (source == null)
            {
                throw new InvalidOperationException(
                    $"No package source whose URL contains '{sourceUrlFilter}' was found in the user NuGet settings.");
            }

            // Keep the configured key as the source name instead of defaulting the name to the URL.
            var packageSource = new PackageSource(source.Value, source.Key);
            RemoteNuGetFeed = new SourceRepository(packageSource, providers);

            PreviewFeatureSettings.DefaultCredentialsAfterCredentialProviders = true;

            // Credential providers are built lazily, on first use; nonInteractive prevents prompting.
            HttpHandlerResourceV3.CredentialService = new Lazy<ICredentialService>(() => new CredentialService(
                new AsyncLazy<IEnumerable<ICredentialProvider>>(() => GetCredentialProvidersAsync()),
                nonInteractive: true,
                handlesDefaultCredentials: PreviewFeatureSettings.DefaultCredentialsAfterCredentialProviders));

            var packageSourceProvider = new PackageSourceProvider(Settings.Settings);
            var sourceRepositoryProvider = new SourceRepositoryProvider(packageSourceProvider, providers);
            var localPackageCachePath = Path.Combine(NuGetEnvironment.GetFolderPath(NuGetFolderPath.NuGetHome), "packages");
            LocalNuGetPackageManager = new NuGetPackageManager(sourceRepositoryProvider, Settings.Settings, localPackageCachePath)
            {
                PackagesFolderNuGetProject = new FolderNuGetProject(localPackageCachePath),
            };
        }

        #region ListPackages/Install API
        /// <summary>
        /// Searches <see cref="RemoteNuGetFeed"/> for packages matching <paramref name="searchTerm"/>.
        /// Uses the feed's search resource when available, otherwise falls back to its list resource.
        /// </summary>
        /// <param name="searchTerm">Search text passed to the feed.</param>
        /// <param name="includePrelease">Whether prerelease packages/versions are included.</param>
        /// <param name="take">Maximum number of packages requested from the search resource
        /// (the list-resource fallback returns every match).</param>
        /// <param name="cancellationToken">Cancels the feed requests.</param>
        /// <returns>Map from package id (case-insensitive, as NuGet ids are) to the versions reported by the feed.</returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="take"/> is not positive.</exception>
        /// <exception cref="InvalidOperationException">No remote feed is configured.</exception>
        /// <exception cref="NotSupportedException">The feed exposes neither a search nor a list resource.</exception>
        public async Task<Dictionary<string, List<NuGetVersion>>> ListPackagesAsync(string searchTerm, bool includePrelease, int take = 30, CancellationToken cancellationToken = default(CancellationToken))
        {
            if (take <= 0)
                throw new ArgumentOutOfRangeException(nameof(take), take, "take must be positive.");
            if (RemoteNuGetFeed == null)
                throw new InvalidOperationException("No remote NuGet feed has been configured.");

            var retVal = new Dictionary<string, List<NuGetVersion>>(StringComparer.OrdinalIgnoreCase);

            var searchResource = await RemoteNuGetFeed.GetResourceAsync<PackageSearchResource>(cancellationToken);
            if (searchResource != null)
            {
                // Honor the caller's prerelease choice (previously hard-coded to true).
                var searchFilter = new SearchFilter(includePrelease) { OrderBy = SearchOrderBy.Id, IncludeDelisted = false };
                var results = await searchResource.SearchAsync(searchTerm, searchFilter, 0, take, NullLogger.Instance, cancellationToken);
                foreach (var result in results)
                {
                    var versionInfos = await result.GetVersionsAsync();
                    // Indexer rather than Add so a repeated id cannot throw.
                    retVal[result.Identity.Id] = versionInfos
                        .Select(vi => vi.Version)
                        .Where(v => includePrelease || !v.IsPrerelease)
                        .ToList();
                }
                return retVal;
            }

            var listResource = await RemoteNuGetFeed.GetResourceAsync<ListResource>(cancellationToken);
            if (listResource == null)
                throw new NotSupportedException("The remote feed supports neither search nor list operations.");

            var allPackages = await listResource.ListAsync(searchTerm, includePrelease, true, false, NullLogger.Instance, cancellationToken);
            var enumerator = allPackages.GetEnumeratorAsync();
            while (await enumerator.MoveNextAsync())
            {
                var current = enumerator.Current;
                if (current == null)
                    break;

                if (!retVal.TryGetValue(current.Identity.Id, out var idVersions))
                {
                    idVersions = new List<NuGetVersion>();
                    retVal.Add(current.Identity.Id, idVersions);
                }
                idVersions.Add(current.Identity.Version);
            }
            return retVal;
        }

        /// <summary>
        /// Collects the reference-item paths of the given packages as installed in the local packages folder.
        /// If a package has exactly one reference group, all of its items are used; otherwise only the
        /// groups whose target framework is .NETFramework are used.
        /// </summary>
        /// <param name="packages">Packages to inspect. Packages that cannot be found on disk are skipped.</param>
        /// <param name="files">Receives the item paths, each combined with the package's directory.</param>
        /// <returns><c>false</c> as soon as processing a package throws (<paramref name="files"/> then holds only
        /// the items gathered before the failure); otherwise <c>true</c>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="packages"/> is null.</exception>
        public bool GetFilesFromPackages(IEnumerable<PackageIdentity> packages, out HashSet<string> files)
        {
            files = new HashSet<string>();
            if (packages == null)
                throw new ArgumentNullException(nameof(packages));

            foreach (var identity in packages)
            {
                try
                {
                    // LocalFolderUtility.GetPackageV3 doesn't work here: packages installed for a project
                    // use a different path layout than the global packages folder.
                    var installedFile = LocalNuGetPackageManager.PackagesFolderNuGetProject.GetInstalledPackageFilePath(identity);
                    var pkg = LocalFolderUtility.GetPackage(new Uri(installedFile), NullLogger.Instance);
                    if (pkg == null)
                        continue;

                    var directoryName = Path.GetDirectoryName(pkg.Path);

                    // The reader holds the package file open; dispose it once the items are read.
                    using (var reader = pkg.GetReader())
                    {
                        var refGroups = reader.GetReferenceItems().ToList();
                        var selectedGroups = refGroups.Count == 1
                            ? refGroups
                            : refGroups.Where(g => string.Equals(g.TargetFramework.Framework, ".NETFramework", StringComparison.OrdinalIgnoreCase));

                        foreach (var group in selectedGroups)
                        {
                            // Use each selected group's own items (not always the first group's).
                            foreach (var item in group.Items)
                            {
                                files.Add(Path.Combine(directoryName, item));
                            }
                        }
                    }
                }
                catch (Exception)
                {
                    return false;
                }
            }
            return true;
        }

        /// <summary>
        /// Installs the given packages and their dependencies concurrently, on a best-effort basis.
        /// A package whose install throws contributes nothing to the result.
        /// </summary>
        /// <param name="packageIdentities">Packages to install.</param>
        /// <returns>The installed identities that can be located in the local packages folder afterwards.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="packageIdentities"/> is null.</exception>
        public async Task<IEnumerable<PackageIdentity>> TryInstallPackagesAsync(IEnumerable<PackageIdentity> packageIdentities)
        {
            if (packageIdentities == null)
                throw new ArgumentNullException(nameof(packageIdentities));

            // Each install yields its own list; merging after WhenAll avoids mutating a shared
            // HashSet from concurrently running continuations (HashSet<T> is not thread-safe).
            var perInstall = await Task.WhenAll(packageIdentities.Select(InstallAndLocateAsync));

            var allPackagesInstalled = new HashSet<PackageIdentity>();
            foreach (var installed in perInstall)
            {
                allPackagesInstalled.UnionWith(installed);
            }
            return allPackagesInstalled;
        }

        /// <summary>
        /// Installs one package (plus dependencies) and returns those resulting identities whose package
        /// can be located on disk; any failure is swallowed deliberately (best-effort contract).
        /// </summary>
        private async Task<List<PackageIdentity>> InstallAndLocateAsync(PackageIdentity installIdentity)
        {
            HashSet<PackageIdentity> installedPkgs;
            try
            {
                installedPkgs = await InstallPackageAndDependenciesAsync(installIdentity);
            }
            catch (Exception)
            {
                return new List<PackageIdentity>();
            }

            var located = new List<PackageIdentity>();
            foreach (var identity in installedPkgs)
            {
                try
                {
                    var installedFile = LocalNuGetPackageManager.PackagesFolderNuGetProject.GetInstalledPackageFilePath(identity);
                    var pkg = LocalFolderUtility.GetPackage(new Uri(installedFile), NullLogger.Instance);
                    if (pkg != null)
                        located.Add(identity);
                }
                catch (Exception)
                {
                    // Skip identities whose package cannot be located.
                }
            }
            return located;
        }

        /// <summary>
        /// Resolves and installs <paramref name="identity"/> (and the dependencies the resolver selects)
        /// into the packages folder, using <see cref="RemoteNuGetFeed"/> as the only source.
        /// </summary>
        /// <param name="identity">Package to install.</param>
        /// <returns>Identities of the planned install actions: the requested package and its resolved dependencies.</returns>
        /// <remarks>Exceptions raised by NuGet propagate to the caller.</remarks>
        private async Task<HashSet<PackageIdentity>> InstallPackageAndDependenciesAsync(PackageIdentity identity)
        {
            const bool allowUnlisted = false;
            const bool includePrelease = true;

            // HighestPatch dependency behavior, with major/minor version constraints passed to the resolver.
            var resolutionContext = new ResolutionContext(
                DependencyBehavior.HighestPatch,
                includePrelease,
                allowUnlisted,
                VersionConstraints.ExactMajor | VersionConstraints.ExactMinor);
            INuGetProjectContext projectContext = new BlankProjectContext(Settings.Settings, NullLogger.Instance);
            var packagesFolder = LocalNuGetPackageManager.PackagesFolderNuGetProject;

            // Materialize once: the actions are both executed and projected into the result.
            var installActions = (await LocalNuGetPackageManager.PreviewInstallPackageAsync(
                packagesFolder, identity,
                resolutionContext, projectContext,
                new[] { RemoteNuGetFeed }, Array.Empty<SourceRepository>(),
                CancellationToken.None)).ToList();

            // SourceCacheContext is IDisposable; release it once execution completes.
            using (var sourceCache = new SourceCacheContext())
            {
                await LocalNuGetPackageManager.ExecuteNuGetProjectActionsAsync(packagesFolder, installActions, projectContext, sourceCache, CancellationToken.None);
            }

            return new HashSet<PackageIdentity>(installActions.Select(action => action.PackageIdentity));
        }
        #endregion

        #region NUGET CREDENTIAL MAGIC
        /// <summary>
        /// Builds the credential providers used by the process-wide credential service, following the
        /// order NuGet.exe uses: secure plugins first, then legacy plugin providers, then (optionally)
        /// default network credentials.
        /// </summary>
        /// <returns>The providers, in the order they should be consulted.</returns>
        private async Task<IEnumerable<ICredentialProvider>> GetCredentialProvidersAsync()
        {
            // Secure plugins discovered by the plugin manager; dialogs are disabled.
            var secureBuilder = new SecurePluginCredentialProviderBuilder(PluginManager.Instance, canShowDialog: false, logger: NullLogger.Instance);
            var securePlugins = await secureBuilder.BuildAllAsync();

            // Legacy plugin providers located through ExtensionLocator; "Detailed" is the verbosity handed to them.
            var legacyPlugins = new PluginCredentialProviderBuilder(new ExtensionLocator(), Settings.UserSettings, NullLogger.Instance)
                .BuildAll("Detailed")
                .ToList();

            var credentialProviders = new List<ICredentialProvider>(securePlugins);
            credentialProviders.AddRange(legacyPlugins);

            // Default network credentials are appended only behind at least one real plugin.
            var anyPluginFound = legacyPlugins.Count > 0 || securePlugins.Any();
            if (anyPluginFound && PreviewFeatureSettings.DefaultCredentialsAfterCredentialProviders)
            {
                credentialProviders.Add(new DefaultNetworkCredentialsCredentialProvider());
            }

            return credentialProviders;
        }
        /// <summary>
        /// Locates NuGet extensions and legacy credential providers on disk, adapted from NuGet.CommandLine.
        /// </summary>
        public class ExtensionLocator : IExtensionLocator
        {
            private static readonly string ExtensionsDirectoryRoot =
                Path.Combine(Environment.GetFolderPath(
                    Environment.SpecialFolder.LocalApplicationData),
                    "NuGet",
                    "Commands");

            private static readonly string CredentialProvidersDirectoryRoot =
                Path.Combine(Environment.GetFolderPath(
                    Environment.SpecialFolder.LocalApplicationData),
                    "NuGet",
                    "CredentialProviders");

            private static readonly string CredentialProviderPattern = "CredentialProvider*.exe";

            /// <summary>Environment variable holding ';'-separated extra extension search roots.</summary>
            public static readonly string ExtensionsEnvar = "NUGET_EXTENSIONS_PATH";

            /// <summary>Environment variable holding ';'-separated extra credential provider search roots.</summary>
            public static readonly string CredentialProvidersEnvar = "NUGET_CREDENTIALPROVIDERS_PATH";

            /// <summary>
            /// Find paths to all extensions (*.dll under the search roots; *Extensions.dll in the host directory).
            /// </summary>
            public IEnumerable<string> FindExtensions()
            {
                var customPaths = ReadPathsFromEnvar(ExtensionsEnvar);
                return FindAll(
                    ExtensionsDirectoryRoot,
                    customPaths,
                    "*.dll",
                    "*Extensions.dll");
            }

            /// <summary>
            /// Find paths to all credential providers (CredentialProvider*.exe).
            /// </summary>
            public IEnumerable<string> FindCredentialProviders()
            {
                var customPaths = ReadPathsFromEnvar(CredentialProvidersEnvar);
                return FindAll(
                    CredentialProvidersDirectoryRoot,
                    customPaths,
                    CredentialProviderPattern,
                    CredentialProviderPattern);
            }

            /// <summary>
            /// Helper method to locate extensions and credential providers.
            /// The following search locations are checked in this order:
            /// 1) all directories under the paths set by the environment variable
            /// 2) all directories under the global root, e.g. %localappdata%\nuget\commands
            /// 3) the directory of the dotnet executable reported by NuGetEnvironment.GetDotNetLocation()
            ///    (top level only; NuGet.exe used its own directory here)
            /// Search roots that do not exist are skipped.
            /// </summary>
            /// <param name="globalRootDirectory">The global directory to search, including subdirectories.</param>
            /// <param name="customPaths">User-defined search paths, including subdirectories.</param>
            /// <param name="assemblyPattern">The filename pattern to search for.</param>
            /// <param name="nugetDirectoryAssemblyPattern">The more restrictive pattern used for the
            /// host directory, so unrelated assemblies there are not picked up.</param>
            /// <returns>Paths to files matching the patterns in all searched directories.</returns>
            private static IEnumerable<string> FindAll(
                string globalRootDirectory,
                IEnumerable<string> customPaths,
                string assemblyPattern,
                string nugetDirectoryAssemblyPattern)
            {
                var directories = new List<string>(customPaths) { globalRootDirectory };

                var paths = new List<string>();
                foreach (var directory in directories.Where(Directory.Exists))
                {
                    paths.AddRange(Directory.EnumerateFiles(directory, assemblyPattern, SearchOption.AllDirectories));
                }

                // The host directory may contain non-extension assemblies. Loading them to inspect
                // their exports would lock them and slow things down, so only the stricter naming
                // convention (e.g. *Extensions.dll) is accepted there.
                var hostDirectory = Path.GetDirectoryName(NuGetEnvironment.GetDotNetLocation());

                // Path.GetDirectoryName returns "" (not null) for a bare file name such as "dotnet", and
                // Directory.EnumerateFiles throws on an empty or missing directory, so guard both cases.
                if (string.IsNullOrEmpty(hostDirectory) || !Directory.Exists(hostDirectory))
                {
                    return paths;
                }

                paths.AddRange(Directory.EnumerateFiles(hostDirectory, nugetDirectoryAssemblyPattern));
                return paths;
            }

            /// <summary>
            /// Splits the ';'-separated value of environment variable <paramref name="key"/>;
            /// returns an empty list when it is unset or empty.
            /// </summary>
            private static IEnumerable<string> ReadPathsFromEnvar(string key)
            {
                var paths = Environment.GetEnvironmentVariable(key);
                if (string.IsNullOrEmpty(paths))
                {
                    return new List<string>();
                }
                return new List<string>(paths.Split(new[] { ';' }, StringSplitOptions.RemoveEmptyEntries));
            }
        }
        #endregion
        #region SomeClasses for Settings/Project
        /// <summary>
        /// Lazily loaded NuGet configuration. <see cref="Settings"/> (the <see cref="IMachineWideSettings"/>
        /// member) is read from the machine-wide config directory; <see cref="UserSettings"/> is read from
        /// the user settings directory. Directories are resolved at construction; files are read on first access.
        /// </summary>
        public class NugetSettings : IMachineWideSettings
        {
            private readonly Lazy<ISettings> _settings;
            private readonly Lazy<ISettings> _userSettings;

            public NugetSettings()
            {
                _settings = DeferLoad(NuGetFolderPath.MachineWideConfigDirectory);
                _userSettings = DeferLoad(NuGetFolderPath.UserSettingsDirectory);
            }

            /// <summary>Settings loaded from the machine-wide config directory.</summary>
            public ISettings Settings => _settings.Value;

            /// <summary>Settings loaded from the user settings directory.</summary>
            public ISettings UserSettings => _userSettings.Value;

            // Resolves the folder immediately and defers reading its config files until first use.
            // NOTE(review): the user directory is loaded through LoadMachineWideSettings as well, rather
            // than the regular user-settings loader — confirm this is intended.
            private static Lazy<ISettings> DeferLoad(NuGetFolderPath folder)
            {
                var root = NuGetEnvironment.GetFolderPath(folder);
                return new Lazy<ISettings>(() => NuGet.Configuration.Settings.LoadMachineWideSettings(root));
            }
        }

        /// <summary>
        /// Minimal <see cref="INuGetProjectContext"/> that forwards all messages to a NuGet logger,
        /// ignores file conflicts, and extracts packages with <see cref="PackageSaveMode.Defaultv3"/>
        /// and no XML doc files.
        /// </summary>
        public class BlankProjectContext : INuGetProjectContext
        {
            private readonly NuGet.Common.ILogger _logger;

            /// <param name="settings">Settings used to compute the client signing policy.</param>
            /// <param name="logger">Destination for all messages.</param>
            /// <exception cref="ArgumentNullException"><paramref name="logger"/> is null.</exception>
            public BlankProjectContext(ISettings settings, NuGet.Common.ILogger logger)
            {
                _logger = logger ?? throw new ArgumentNullException(nameof(logger));

                var clientPolicy = ClientPolicyContext.GetClientPolicy(settings, logger);
                PackageExtractionContext = new PackageExtractionContext(PackageSaveMode.Defaultv3, XmlDocFileSaveMode.None, clientPolicy, logger);
            }

            public void Log(MessageLevel level, string message, params object[] args)
            {
                // Treat the message as a format string only when arguments were supplied: raw text
                // (file paths, conflict messages) may contain braces, which String.Format rejects.
                var formattedMessage = args != null && args.Length > 0
                    ? String.Format(message, args)
                    : message;

                switch (level)
                {
                    case MessageLevel.Info: _logger.LogInformation(formattedMessage); break;
                    case MessageLevel.Warning: _logger.LogWarning(formattedMessage); break;
                    case MessageLevel.Debug: _logger.LogDebug(formattedMessage); break;
                    case MessageLevel.Error: _logger.LogError(formattedMessage); break;
                    default: _logger.LogInformationSummary(formattedMessage); break;
                }
            }

            /// <summary>Logs the conflict as a warning and keeps the existing file.</summary>
            public FileConflictAction ResolveFileConflict(string message)
            {
                Log(MessageLevel.Warning, message);
                return FileConflictAction.Ignore;
            }

            public void ReportError(string message) => Log(MessageLevel.Error, message);

            public void Log(ILogMessage message)
            {
                if (message == null)
                    return;

                // Forward the structured message so its level, code and text all reach the logger,
                // instead of relying on ToString() and dropping levels a switch does not list.
                _logger.Log(message);
            }

            public void ReportError(ILogMessage message) => Log(message);

            public PackageExtractionContext PackageExtractionContext { get; set; }
            public ISourceControlManagerProvider SourceControlManagerProvider { get; }
            public NuGet.ProjectManagement.ExecutionContext ExecutionContext { get; }
            public XDocument OriginalPackagesConfig { get; set; }
            public NuGetActionType ActionType { get; set; }
            public Guid OperationId { get; set; }
        }
        #endregion
    }
}